diff --git a/api/commands/vector.py b/api/commands/vector.py index bef18bf73b..cb7eb7c452 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,7 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -86,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -178,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index fb98932269..27c772fbe0 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -3,7 +3,7 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select +from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -355,7 +356,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -436,7 +437,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -454,7 +455,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -485,7 +486,7 @@ class DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): @@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource): _, current_tenant_id = current_account_with_tenant() current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) - .count() + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource): def delete(self, api_key_id): _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.delete(key) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 074694e7ea..897724182f 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -27,6 +27,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db @@ -449,7 +450,7 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: @@ -463,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -1337,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index fa9bc7f159..7333fcaa07 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,6 +26,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): payload = BatchImportPayload.model_validate(console_ns.payload or {}) upload_file_id = payload.upload_file_id - upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1)) if not upload_file: raise NotFound("UploadFile not found.") @@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 89be847cd3..25b6436a71 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -15,6 +15,7 @@ from controllers.service_api.wraps import ( cloud_edition_billing_rate_limit_check, ) from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag @@ -153,9 +154,14 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: item["embedding_available"] = True # type: ignore else: @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..595b01a9f2 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -17,6 +17,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields @@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 87d4772815..0bd904811a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,6 +4,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models.dataset import Dataset from models.enums import CollectionBindingType, ConversationFromSource @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..06bc366081 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -271,7 +271,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. @@ -289,7 +289,7 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, @@ -303,7 +303,7 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -573,7 +573,7 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -587,7 +587,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +597,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +628,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +654,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,7 +764,7 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..cd27113245 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,6 +6,7 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -71,7 +72,7 @@ class DatasetDocumentStore: if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index d9145023ac..a6d1db214b 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview @@ -159,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b1707..726cc062f6 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index df0761ca73..70504e6e50 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 62f88b7760..6874603a83 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a5..52061fd93d 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -675,7 +675,7 @@ class DatasetRetrieval: # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if selected_dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -752,7 +752,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -1068,7 +1068,7 @@ class DatasetRetrieval: else default_retrieval_model ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 31d21dbeee..6f120bd471 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService @@ -21,7 +22,7 @@ class SummaryIndex: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..75b923fd8b 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 429b7e6622..f3d390ed59 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] diff --git a/api/models/dataset.py b/api/models/dataset.py index b4fb03a7f4..e323ccfd7f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -137,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index dff5b71e7e..68ff37bcaa 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -940,7 +940,9 @@ class AccountTrialAppRecord(Base): class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1849,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase): sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) source: Mapped[str] = mapped_column(LongText, nullable=False) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 65e112f1e9..969ca68545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.file import helpers as file_helpers from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType @@ -228,7 +228,7 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid @@ -254,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -349,7 +352,7 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -717,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -953,8 +956,8 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error @@ -976,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -991,9 +994,9 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: @@ -1018,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1029,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1089,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1907,8 +1910,8 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model @@ -2689,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2712,7 +2715,7 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3125,7 +3128,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3208,7 +3211,7 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3230,7 +3233,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( @@ -3345,7 +3348,7 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3382,7 +3385,7 @@ class SegmentService: # When user manually provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3409,7 +3412,7 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3449,7 +3452,7 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3481,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d3..fd66d55c1a 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 7dcfecdd1d..215a8c8528 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,7 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -105,29 +105,29 @@ class RagPipelineTransformService: if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) @@ -170,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972b..ed7a33feae 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,6 +12,7 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -140,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -724,7 +725,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -851,7 +852,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -889,7 +890,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. """ # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -981,7 +982,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1012,7 +1013,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..bb94a03ba3 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document @@ -45,7 +45,7 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -112,7 +112,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +197,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +237,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +252,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2..dafa36cc34 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf03454..c734e1321b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af95..c9aa8fadb7 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b00..41cf7ccbf6 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d62..2c07fe0f31 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e74..f41da1d373 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 7f810129ef..dd58378e0e 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,7 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -120,7 +120,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index b5794e33e2..23a80fa106 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,7 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -127,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc..e3d82d2851 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index ac5d23408a..6f490ab7ea 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,7 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -53,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index ea8d04502a..00d7496a40 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -39,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -460,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d7..02c3d1a80e 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 9ca8729b77..a83af30fb9 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -1245,3 +1245,51 @@ class TestAppService: assert paginated_apps is not None assert paginated_apps.total == 1 assert all("50%" in app.name for app in paginated_apps.items) + + def test_get_app_code_by_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_code_by_id raises ValueError when site is missing.""" + from uuid import uuid4 + + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id(str(uuid4())) + + def test_get_app_id_by_code_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_id_by_code raises ValueError when code does not exist.""" + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("nonexistent-code") + + def test_get_app_meta_returns_empty_when_workflow_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when workflow is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + workflow_app = SimpleNamespace(mode="workflow", workflow=None) + + meta = app_service.get_app_meta(workflow_app) + assert meta == {"tool_icons": {}} + + def test_get_app_meta_returns_empty_when_model_config_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when app_model_config is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + chat_app = SimpleNamespace(mode="chat", app_model_config=None) + + meta = app_service.get_app_meta(chat_app) + assert meta == {"tool_icons": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 55bfb64e18..71c8874f79 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index c4d20bc02c..0702680f5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -63,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -157,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -181,13 +181,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -273,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -306,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -314,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -589,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -685,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 807d18322c..3cac964d89 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,7 +3,7 @@ from unittest.mock import patch from uuid import uuid4 -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -109,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -208,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2..87239b2cb3 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d..2f90d16176 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..2899d5b8a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -4,6 +4,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 1a72e3b6c2..f504f35589 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 84ce6364df..880143013e 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import json import uuid from datetime import UTC, datetime, timedelta +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -8,14 +11,14 @@ from faker import Faker from sqlalchemy.orm import Session from dify_graph.entities.workflow_execution import WorkflowExecutionStatus -from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun -from models.enums import CreatorUserRole +from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService -from services.workflow_app_service import WorkflowAppService +from services.workflow_app_service import LogView, WorkflowAppService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -1525,3 +1528,168 @@ class TestWorkflowAppService: # Should not find tenant2's data when searching from tenant1's context assert result_cross_tenant["total"] == 0 + + def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"): + service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + ) + + def test_get_paginate_workflow_app_logs_filters_by_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account) + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account=account.email, + ) + + assert result["total"] >= 0 + assert isinstance(result["data"], list) + + def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type="browser", + is_anonymous=False, + session_id="session-1", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + + now = datetime.now(UTC) + archive_defaults = { + "workflow_id": str(uuid.uuid4()), + "run_version": "1.0.0", + "run_status": WorkflowExecutionStatus.SUCCEEDED, + "run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN, + "run_error": None, + "run_elapsed_time": 1.0, + "run_total_tokens": 0, + "run_total_steps": 0, + "run_created_at": now, + "run_finished_at": now, + "run_exceptions_count": 0, + "trigger_metadata": '{"type":"trigger-webhook"}', + "log_created_at": now, + "log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API, + } + archive_account = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=account.id, + created_by_role=CreatorUserRole.ACCOUNT, + **archive_defaults, + ) + archive_end_user = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=end_user.id, + created_by_role=CreatorUserRole.END_USER, + **archive_defaults, + ) + db_session_with_containers.add_all([archive_account, archive_end_user]) + db_session_with_containers.commit() + + result = service.get_paginate_workflow_archive_logs( + session=db_session_with_containers, + app_model=app, + page=1, + limit=20, + ) + + assert result["total"] == 2 + assert len(result["data"]) == 2 + account_item = next(d for d in result["data"] if d["created_by_account"] is not None) + end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None) + assert account_item["created_by_account"].id == account.id + assert end_user_item["created_by_end_user"].id == end_user.id + + +class TestLogView: + def test_details_and_proxy_attributes(self): + log = SimpleNamespace(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + assert view.details == {"trigger_metadata": {"type": "plugin"}} + assert view.status == "succeeded" + + +class TestHandleTriggerMetadata: + def test_returns_empty_dict_when_metadata_missing(self): + service = WorkflowAppService() + assert service.handle_trigger_metadata("tenant-1", None) == {} + + def test_enriches_plugin_icons(self): + service = WorkflowAppService() + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + with patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + def test_non_plugin_metadata_without_icon_lookup(self): + service = WorkflowAppService() + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +class TestSafeJsonLoads: + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], + ) + def test_handles_various_inputs(self, value, expected): + assert WorkflowAppService._safe_json_loads(value) == expected + + +class TestSafeParseUuid: + def test_returns_none_for_short_or_invalid_values(self): + service = WorkflowAppService() + assert service._safe_parse_uuid("short") is None + assert service._safe_parse_uuid("x" * 40) is None + + def test_returns_uuid_for_valid_string(self): + service = WorkflowAppService() + raw = str(uuid.uuid4()) + result = service._safe_parse_uuid(raw) + assert result is not None + assert str(result) == raw diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 7ab059bb75..2dc50cc720 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -1,12 +1,24 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest from faker import Faker from sqlalchemy.orm import Session -from core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolDescription, + ToolEntity, + ToolIdentity, + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -786,3 +798,192 @@ class TestToolTransformService: assert result is not None assert result == mock_controller mock_from_db.assert_called_once_with(provider) + + +def _mock_tool(*, base_params, runtime_params): + """Helper to build a Mock tool with real entity objects. + + Tool is abstract and requires runtime behaviour (fork_tool_runtime, + get_runtime_parameters), so it stays as a Mock. Everything else uses + real Pydantic instances. + """ + entity = ToolEntity( + identity=ToolIdentity( + author="test_author", + name="test_tool", + label=I18nObject(en_US="Test Tool"), + provider="test_provider", + ), + parameters=base_params or [], + description=ToolDescription( + human=I18nObject(en_US="Test description"), + llm="Test description for LLM", + ), + output_schema={}, + ) + mock_tool = Mock(spec=Tool) + mock_tool.entity = entity + mock_tool.get_runtime_parameters.return_value = runtime_params + mock_tool.fork_tool_runtime.return_value = mock_tool + return mock_tool + + +def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None): + return ToolParameter( + name=name, + label=I18nObject(en_US=label or name), + human_description=I18nObject(en_US=name), + type=ToolParameter.ToolParameterType.STRING, + form=form, + ) + + +class TestConvertToolEntityToApiEntity: + """Tests for ToolTransformService.convert_tool_entity_to_api_entity.""" + + def test_parameter_override(self): + base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")] + runtime = [_param("param1", label="Runtime 1")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 2 + assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1" + assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2" + + def test_additional_runtime_parameters(self): + base = [_param("param1", label="Base 1")] + runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 2 + names = [p.name for p in result.parameters] + assert "param1" in names + assert "runtime_only" in names + + def test_non_form_runtime_parameters_excluded(self): + base = [_param("param1")] + runtime = [ + _param("param1", label="Runtime 1"), + _param("llm_param", form=ToolParameter.ToolParameterForm.LLM), + ] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 1 + assert result.parameters[0].name == "param1" + + def test_empty_parameters(self): + tool = _mock_tool(base_params=[], runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_none_parameters(self): + tool = _mock_tool(base_params=None, runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_parameter_order_preserved(self): + base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")] + runtime = [_param("p2", label="R2"), _param("p4", label="R4")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"] + assert result.parameters[1].label.en_US == "R2" + + +class TestWorkflowProviderToUserProvider: + """Tests for ToolTransformService.workflow_provider_to_user_provider.""" + + @staticmethod + def _make_controller(provider_id="provider_123", **identity_overrides): + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + defaults = { + "author": "test_author", + "name": "test_workflow_tool", + "description": I18nObject(en_US="Test description"), + "icon": '{"type": "emoji", "content": "🔧"}', + "icon_dark": None, + "label": I18nObject(en_US="Test Workflow Tool"), + } + defaults.update(identity_overrides) + identity = ToolProviderIdentity(**defaults) + entity = ToolProviderEntity(identity=identity) + return WorkflowToolProviderController(entity=entity, provider_id=provider_id) + + def test_with_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1", "l2"], + workflow_app_id="app_123", + ) + + assert isinstance(result, ToolProviderApiEntity) + assert result.id == "provider_123" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_123" + assert result.labels == ["l1", "l2"] + assert result.is_team_authorization is True + + def test_without_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1"], + ) + + assert result.workflow_app_id is None + + def test_workflow_app_id_none_explicit(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=None, + workflow_app_id=None, + ) + + assert result.workflow_app_id is None + assert result.labels == [] + + def test_preserves_other_fields(self): + ctrl = self._make_controller( + "provider_456", + author="another_author", + name="another_workflow_tool", + description=I18nObject(en_US="Another desc", zh_Hans="Another desc"), + icon='{"type": "emoji", "content": "⚙️"}', + icon_dark='{"type": "emoji", "content": "🔧"}', + label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"), + ) + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["automation"], + workflow_app_id="app_456", + ) + + assert result.id == "provider_456" + assert result.author == "another_author" + assert result.name == "another_workflow_tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_456" + assert result.is_team_authorization is True + assert result.allow_delete is True diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bf..4b04c1accb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 5ebf141828..d2e343ef52 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,7 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -142,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 9449fee0af..1dd37fbc92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,7 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -154,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -870,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 979435282b..9f8e37fc9e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,7 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -121,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc7011..13ea94348a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9c..8a69707b38 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index d21f1daf23..5bdf7d1389 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,7 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -100,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index fbcb7b5264..3e9a0c8f7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -103,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 10d97919fb..d4021143ef 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,7 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -57,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b07285..cf1a8666f3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index c650d56091..d94abf2b40 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -64,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 76b6a8ae73..6a8e186958 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -110,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -245,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8..e2f35067e3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 68a7b30b9e..ff565f19fd 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), ): response, status = method(api, "dataset-1") @@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi: document.error = None document.stopped_at = None - # First count = completed segments, second = total segments - query_mock = MagicMock() - query_mock.where.side_effect = [ - MagicMock(count=lambda: 2), - MagicMock(count=lambda: 5), - ] - with ( app.test_request_context("/"), patch( @@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets.db.session.scalar", + side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") @@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken.generate_api_key", @@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: @@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", @@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 1482499c41..306a772fd1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(ValueError): @@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -831,8 +831,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -880,8 +880,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -924,11 +924,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -970,11 +967,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -1180,8 +1174,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(NotFound): @@ -1215,8 +1209,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index e6cc582398..2c234edd9a 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -21,7 +22,7 @@ class TestParagraphIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -167,7 +168,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_with_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -178,7 +179,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_without_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -208,7 +209,7 @@ class TestParagraphIndexProcessor: def test_clean_economy_deletes_summaries_and_keywords( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( @@ -222,7 +223,7 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete.assert_called_once() def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: processor.clean(dataset, ["node-2"], with_keywords=True) @@ -267,7 +268,7 @@ class TestParagraphIndexProcessor: def test_index_list_chunks_economy( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index 5c78cae7c1..b1ed735ee7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -19,7 +20,7 @@ class TestParentChildIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 99323eeec9..98c47bec8f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -6,6 +6,7 @@ import pytest from werkzeug.datastructures import FileStorage from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.models.document import AttachmentDocument, Document @@ -33,7 +34,7 @@ class TestQAIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,7 +208,7 @@ class TestQAIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="Q1", metadata={"answer": "A1"})] with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: @@ -298,7 +299,7 @@ class TestQAIndexProcessor: def test_index_requires_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) with ( diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade884..b54a74b69c 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,7 +61,7 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now @@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument def create_mock_dataset( dataset_id: str | None = None, tenant_id: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", ) -> Mock: @@ -458,7 +458,7 @@ class TestIndexingRunnerTransform: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -521,7 +521,7 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() transformed_docs = [ @@ -605,7 +605,7 @@ class TestIndexingRunnerLoad: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -674,7 +674,7 @@ class TestIndexingRunnerLoad: """Test loading with economy indexing (keyword only).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -701,7 +701,7 @@ class TestIndexingRunnerLoad: # Arrange runner = IndexingRunner() sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - sample_dataset.indexing_technique = "high_quality" + sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # Add child documents for doc in sample_documents: @@ -795,7 +795,7 @@ class TestIndexingRunnerRun: mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) @@ -949,7 +949,7 @@ class TestIndexingRunnerRun: mock_dependencies["db"].session.get.side_effect = get_side_effect mock_dataset = Mock(spec=Dataset) - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab..feb560bbc3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -5,6 +5,7 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -78,7 +79,7 @@ def sample_node_data(): type="knowledge-index", chunk_structure="general_structure", index_chunk_variable_selector=["start", "chunks"], - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, summary_index_setting=None, ) diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 98dd07907a..6c8a91129b 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -15,6 +15,7 @@ from datetime import UTC, datetime from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import ( AppDatasetJoin, ChildChunk, @@ -67,14 +68,14 @@ class TestDatasetModelValidation: data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), description="Test description", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) # Assert assert dataset.description == "Test description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -86,21 +87,21 @@ class TestDatasetModelValidation: name="High Quality Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) dataset_economy = Dataset( tenant_id=str(uuid4()), name="Economy Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) # Assert - assert dataset_high_quality.indexing_technique == "high_quality" - assert dataset_economy.indexing_technique == "economy" - assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST - assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST + assert dataset_high_quality.indexing_technique == IndexTechniqueType.HIGH_QUALITY + assert dataset_economy.indexing_technique == IndexTechniqueType.ECONOMY + assert IndexTechniqueType.HIGH_QUALITY in Dataset.INDEXING_TECHNIQUE_LIST + assert IndexTechniqueType.ECONOMY in Dataset.INDEXING_TECHNIQUE_LIST def test_dataset_provider_validation(self): """Test dataset provider values.""" @@ -983,7 +984,7 @@ class TestModelIntegration: name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) dataset.id = dataset_id @@ -1019,7 +1020,7 @@ class TestModelIntegration: assert document.dataset_id == dataset_id assert segment.dataset_id == dataset_id assert segment.document_id == document_id - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert document.word_count == 100 assert segment.status == SegmentStatus.COMPLETED diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index c805dd98e2..424ac18870 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory: name: str = "Test Dataset", description: str = "Test description", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str | None = "openai", embedding_model: str | None = "text-embedding-ada-002", collection_binding_id: str | None = "binding-123", @@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_knowledge_configuration_mock( chunk_structure: str = "tree", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", keyword_number: int = 10, @@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: # Assert assert dataset.chunk_structure == "list" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == "binding-123" @@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", # Existing structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", # Different structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) mock_session.merge.return_value = dataset @@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( dataset_id="dataset-123", runtime_mode="rag_pipeline", - indexing_technique="high_quality", # Current technique + indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - indexing_technique="economy", # Trying to change to economy + indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy ) mock_session.merge.return_value = dataset diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 1f68ff6b3d..49fdc5cc9b 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,7 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService @@ -154,7 +154,7 @@ class DocumentValidationTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", **kwargs, @@ -190,7 +190,7 @@ class DocumentValidationTestDataFactory: data_source: DataSource | None = None, process_rule: ProcessRule | None = None, doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, **kwargs, ) -> Mock: """ @@ -448,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -481,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting: - No errors are raised """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) # Act (should not raise) DatasetService.check_dataset_model_setting(dataset) @@ -503,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="invalid-model", ) @@ -533,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index 5e625fa0cd..14af7f7119 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import SegmentType @@ -111,7 +111,7 @@ class SegmentTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model: str = "text-embedding-ada-002", embedding_model_provider: str = "openai", **kwargs, @@ -163,7 +163,7 @@ class TestSegmentServiceCreateSegment: """Test successful creation of a segment.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test", "segment"]} mock_query = MagicMock() @@ -212,7 +212,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with QA model (requires answer).""" # Arrange document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} mock_query = MagicMock() @@ -247,7 +247,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with high quality indexing technique.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -289,7 +289,7 @@ class TestSegmentServiceCreateSegment: """Test segment creation when vector indexing fails.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -342,7 +342,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment @@ -431,7 +431,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py deleted file mode 100644 index a6bc79e82b..0000000000 --- a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -Unit tests for services.advanced_prompt_template_service -""" - -import copy - -from core.prompt.prompt_templates.advanced_prompt_templates import ( - BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_CONTEXT, - CHAT_APP_CHAT_PROMPT_CONFIG, - CHAT_APP_COMPLETION_PROMPT_CONFIG, - COMPLETION_APP_CHAT_PROMPT_CONFIG, - COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - CONTEXT, -) -from models.model import AppMode -from services.advanced_prompt_template_service import AdvancedPromptTemplateService - - -class TestAdvancedPromptTemplateService: - """Test suite for AdvancedPromptTemplateService.""" - - def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: - """Test baichuan model names use baichuan context prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "chat", - "model_name": "Baichuan2-13B", - "has_context": "true", - } - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - - def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: - """Test non-baichuan model names use common prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "completion", - "model_name": "gpt-4", - "has_context": "false", - } - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result == original_config - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: - """Test invalid app mode returns empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "chat" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} - - def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: - """Test context is prepended for completion prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: - """Test context is prepended for chat prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) - assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: - """Test chat prompt remains unchanged when has_context is false.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") - - # Assert - assert result == original_config - assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: - """Test completion app mode with completion model returns completion prompt.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") - - # Assert - assert result == original_config - assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: - """Test invalid model mode returns empty dict.""" - # Arrange - app_mode = AppMode.CHAT - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") - - # Assert - assert result == {} - - def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps completion prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] - - # Act - result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"] == original_text - - def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps chat prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] - - # Act - result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text - - def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: - """Test baichuan chat/completion returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: - """Test baichuan completion/chat returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: - """Test baichuan completion/completion prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: - """Test baichuan chat/chat prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: - """Test invalid baichuan mode combinations return empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py deleted file mode 100644 index 95fc28b1e7..0000000000 --- a/api/tests/unit_tests/services/test_app_service.py +++ /dev/null @@ -1,683 +0,0 @@ -"""Unit tests for services.app_service.""" - -import json -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock, patch - -import pytest - -from core.errors.error import ProviderTokenNotInitError -from models import Account, Tenant -from models.model import App, AppMode, IconType -from services.app_service import AppService - - -@pytest.fixture -def service() -> AppService: - """Provide AppService instance.""" - return AppService() - - -@pytest.fixture -def account() -> Account: - """Create account object for create_app tests.""" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - result = Account(name="Account User", email="account@example.com") - result.id = "acc-1" - result._current_tenant = tenant - return result - - -@pytest.fixture -def default_args() -> dict: - """Create default create_app args.""" - return { - "name": "Test App", - "mode": AppMode.CHAT.value, - "icon": "🤖", - "icon_background": "#FFFFFF", - } - - -@pytest.fixture -def app_template() -> dict: - """Create basic app template for create_app tests.""" - return { - AppMode.CHAT: { - "app": {}, - "model_config": { - "model": { - "provider": "provider-a", - "name": "model-a", - "mode": "chat", - "completion_params": {}, - } - }, - } - } - - -def _make_current_user() -> Account: - user = Account(name="Tester", email="tester@example.com") - user.id = "user-1" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - user._current_tenant = tenant - return user - - -class TestAppServicePagination: - """Test suite for get_paginate_apps.""" - - def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: - """Test pagination returns None when tag filter has no targets.""" - # Arrange - args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} - - with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is None - - def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: - """Test pagination delegates to db.paginate when filters are valid.""" - # Arrange - args = { - "mode": "workflow", - "is_created_by_me": True, - "name": "My_App%", - "tag_ids": ["tag-1"], - "page": 2, - "limit": 10, - } - expected_pagination = MagicMock() - - with ( - patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), - patch("libs.helper.escape_like_pattern", return_value="escaped"), - patch("services.app_service.db") as mock_db, - ): - mock_db.paginate.return_value = expected_pagination - - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is expected_pagination - mock_db.paginate.assert_called_once() - - -class TestAppServiceCreate: - """Test suite for create_app.""" - - def test_create_app_should_create_with_matching_default_model( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app uses matching default model and persists app config.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - model_instance = SimpleNamespace( - model_name="model-a", - provider="provider-a", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) - mock_config.BILLING_ENABLED = True - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - assert app_instance.app_model_config_id == "cfg-1" - mock_db.session.add.assert_any_call(app_instance) - mock_db.session.add.assert_any_call(app_model_config) - assert mock_db.session.flush.call_count == 2 - mock_db.session.commit.assert_called_once() - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - - def test_create_app_should_raise_when_model_schema_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app raises ValueError when non-matching model has no schema.""" - # Arrange - app_instance = SimpleNamespace(id="app-1") - model_instance = SimpleNamespace( - model_name="model-b", - provider="provider-b", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - model_instance.model_type_instance.get_model_schema.return_value = None - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - - # Act & Assert - with pytest.raises(ValueError, match="model schema not found"): - service.create_app("tenant-1", default_args, account) - mock_db.session.commit.assert_not_called() - - def test_create_app_should_fallback_to_default_provider_when_model_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app falls back to provider/model name when no default model instance is available.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) - mock_config.BILLING_ENABLED = False - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") - - def test_create_app_should_log_and_fallback_on_unexpected_model_error( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test unexpected model manager errors are logged and fallback provider is used.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db"), - patch("services.app_service.app_was_created"), - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), - ), - patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), - patch("services.app_service.logger") as mock_logger, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = RuntimeError("boom") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_logger.exception.assert_called_once() - - -class TestAppServiceGetAndUpdate: - """Test suite for app retrieval and update methods.""" - - def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None: - """Test get_app returns original app for non-agent modes.""" - # Arrange - app = MagicMock() - app.mode = AppMode.CHAT - app.is_agent = False - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: - """Test get_app returns app when agent mode has no model config.""" - # Arrange - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = None - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: - """Test get_app decrypts and masks secret tool parameters.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - manager = MagicMock() - manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} - manager.mask_tool_parameters.return_value = {"secret": "***"} - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), - patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - assert tool["tool_parameters"] == {"secret": "***"} - assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} - - def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: - """Test get_app logs and continues when masking fails.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), - patch("services.app_service.logger") as mock_logger, - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - mock_logger.exception.assert_called_once() - - def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: - """Test update methods set fields and commit changes.""" - # Arrange - app = cast( - App, - SimpleNamespace( - name="old", - description="old", - icon_type="emoji", - icon="a", - icon_background="#111", - enable_site=True, - enable_api=True, - ), - ) - args = { - "name": "new", - "description": "new-desc", - "icon_type": "image", - "icon": "new-icon", - "icon_background": "#222", - "use_icon_as_answer_icon": True, - "max_active_requests": 5, - } - user = SimpleNamespace(id="user-1") - - with ( - patch("services.app_service.current_user", user), - patch("services.app_service.db") as mock_db, - patch("services.app_service.naive_utc_now", return_value="now"), - ): - # Act - updated = service.update_app(app, args) - renamed = service.update_app_name(app, "rename") - iconed = service.update_app_icon(app, "icon-2", "#333") - site_same = service.update_app_site_status(app, app.enable_site) - api_same = service.update_app_api_status(app, app.enable_api) - site_changed = service.update_app_site_status(app, False) - api_changed = service.update_app_api_status(app, False) - - # Assert - assert updated is app - assert updated.icon_type == IconType.IMAGE - assert renamed is app - assert iconed is app - assert site_same is app - assert api_same is app - assert site_changed is app - assert api_changed is app - assert mock_db.session.commit.call_count >= 5 - - def test_update_app_should_preserve_icon_type_when_not_provided(self, service: AppService) -> None: - """Test update_app keeps the existing icon_type when the payload omits it.""" - # Arrange - app = cast( - App, - SimpleNamespace( - name="old", - description="old", - icon_type=IconType.EMOJI, - icon="a", - icon_background="#111", - use_icon_as_answer_icon=False, - max_active_requests=1, - ), - ) - args = { - "name": "new", - "description": "new-desc", - "icon_type": None, - "icon": "new-icon", - "icon_background": "#222", - "use_icon_as_answer_icon": True, - "max_active_requests": 5, - } - user = SimpleNamespace(id="user-1") - - with ( - patch("services.app_service.current_user", user), - patch("services.app_service.db") as mock_db, - patch("services.app_service.naive_utc_now", return_value="now"), - ): - # Act - updated = service.update_app(app, args) - - # Assert - assert updated is app - assert updated.icon_type == IconType.EMOJI - mock_db.session.commit.assert_called_once() - - def test_update_app_should_reject_empty_icon_type(self, service: AppService) -> None: - """Test update_app rejects an explicit empty icon_type.""" - app = cast( - App, - SimpleNamespace( - name="old", - description="old", - icon_type=IconType.EMOJI, - icon="a", - icon_background="#111", - use_icon_as_answer_icon=False, - max_active_requests=1, - ), - ) - args = { - "name": "new", - "description": "new-desc", - "icon_type": "", - "icon": "new-icon", - "icon_background": "#222", - "use_icon_as_answer_icon": True, - "max_active_requests": 5, - } - user = SimpleNamespace(id="user-1") - - with ( - patch("services.app_service.current_user", user), - patch("services.app_service.db") as mock_db, - ): - with pytest.raises(ValueError): - service.update_app(app, args) - - mock_db.session.commit.assert_not_called() - - -class TestAppServiceDeleteAndMeta: - """Test suite for delete and metadata methods.""" - - def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: - """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" - # Arrange - app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) - - with ( - patch("services.app_service.db") as mock_db, - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), - ), - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch( - "services.app_service.dify_config", - new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), - ), - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.remove_app_and_related_data_task") as mock_task, - ): - # Act - service.delete_app(app) - - # Assert - mock_db.session.delete.assert_called_once_with(app) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") - - def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: - """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" - # Arrange - workflow = SimpleNamespace( - graph_dict={ - "nodes": [ - { - "data": { - "type": "tool", - "provider_type": "builtin", - "provider_id": "builtin-provider", - "tool_name": "tool_builtin", - } - }, - { - "data": { - "type": "tool", - "provider_type": "api", - "provider_id": "api-provider-id", - "tool_name": "tool_api", - } - }, - ] - } - ) - app = cast( - App, - SimpleNamespace( - mode=AppMode.WORKFLOW.value, - workflow=workflow, - app_model_config=None, - tenant_id="tenant-1", - icon_type="emoji", - icon_background="#fff", - ), - ) - - provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = provider - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") - assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} - - def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: - """Test get_app_meta falls back to default icon when API provider lookup fails.""" - # Arrange - app_model_config = SimpleNamespace( - agent_mode_dict={ - "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] - } - ) - app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} - - def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: - """Test get_app_meta returns empty metadata when workflow/model config is absent.""" - # Arrange - workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) - chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) - - # Act - workflow_meta = service.get_app_meta(workflow_app) - chat_meta = service.get_app_meta(chat_app) - - # Assert - assert workflow_meta == {"tool_icons": {}} - assert chat_meta == {"tool_icons": {}} - - -class TestAppServiceCodeLookup: - """Test suite for app code lookup methods.""" - - def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: - """Test get_app_code_by_id raises when site is missing.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_code_by_id("app-1") - - def test_get_app_code_by_id_should_return_code(self) -> None: - """Test get_app_code_by_id returns site code.""" - # Arrange - site = SimpleNamespace(code="code-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_code_by_id("app-1") - - # Assert - assert result == "code-1" - - def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: - """Test get_app_id_by_code raises when code does not exist.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_id_by_code("missing") - - def test_get_app_id_by_code_should_return_app_id(self) -> None: - """Test get_app_id_by_code returns linked app id.""" - # Arrange - site = SimpleNamespace(app_id="app-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_id_by_code("code-1") - - # Assert - assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index d2287e8982..9a513c3fe6 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -71,7 +71,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch # Minimal knowledge_config stub that satisfies pre-lock code info_list = types.SimpleNamespace(data_source_type="upload_file") @@ -80,7 +80,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model=None, embedding_model_provider=None, retrieval_model=None, @@ -126,7 +126,7 @@ def test_add_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # skip embedding/token calculation branch + dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch document = create_autospec(Document, instance=True) document.id = "doc-1" @@ -169,7 +169,7 @@ def test_multi_create_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # again, skip high_quality path + dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path document = create_autospec(Document, instance=True) document.id = "doc-1" diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index c4285c73a0..ef53df9350 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -27,7 +27,7 @@ class _SessionContext: return None -def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: +def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock: dataset = MagicMock(name="dataset") dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" @@ -169,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: vector_cls = MagicMock() monkeypatch.setattr(summary_module, "Vector", vector_cls) - SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset) vector_cls.assert_not_called() @@ -621,7 +622,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _dataset(indexing_technique="economy") + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" document.doc_form = IndexStructureType.PARAGRAPH_INDEX @@ -778,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo def test_enable_summaries_for_segments_skips_non_high_quality() -> None: - SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY)) def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: @@ -932,9 +933,8 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon def test_update_summary_for_segment_skip_conditions() -> None: - assert ( - SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None - ) + economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None seg = _segment(has_document=True) seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index d3a98dd4bb..16d3011810 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from services.vector_service import VectorService @@ -32,7 +32,7 @@ class _ParentDocStub: def _make_dataset( *, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", @@ -192,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider="openai", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -241,7 +241,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -329,7 +329,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) segment = _make_segment() dataset_document = MagicMock() @@ -348,7 +348,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = _make_segment() vector_instance = MagicMock() @@ -364,7 +364,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -380,7 +380,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -473,7 +473,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = MagicMock() child_chunk.content = "child" child_chunk.index_node_id = "id" @@ -489,7 +489,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) @@ -505,7 +505,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = MagicMock() new_chunk.content = "n" @@ -536,7 +536,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) VectorService.update_child_chunk_vector([], [], [], dataset) @@ -561,7 +561,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) vector_cls = MagicMock() @@ -575,7 +575,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) vector_cls = MagicMock() @@ -591,7 +591,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) vector_instance = MagicMock(name="vector_instance") @@ -612,7 +612,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -630,7 +630,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -663,7 +663,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -683,7 +683,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() diff --git a/api/tests/unit_tests/services/test_webapp_auth_service.py b/api/tests/unit_tests/services/test_webapp_auth_service.py deleted file mode 100644 index 262c1f1524..0000000000 --- a/api/tests/unit_tests/services/test_webapp_auth_service.py +++ /dev/null @@ -1,379 +0,0 @@ -from __future__ import annotations - -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import NotFound, Unauthorized - -from models import Account, AccountStatus -from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError -from services.webapp_auth_service import WebAppAuthService, WebAppAuthType - -ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback" -TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token" -TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data" - - -def _account(**kwargs: Any) -> Account: - return cast(Account, SimpleNamespace(**kwargs)) - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.webapp_auth_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None: - # Arrange - mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) - - # Act + Assert - with pytest.raises(AccountNotFoundError): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act + Assert - with pytest.raises(AccountLoginError, match="Account is banned"): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -@pytest.mark.parametrize("password_value", [None, "hash"]) -def test_authenticate_should_raise_password_error_when_password_is_invalid( - password_value: str | None, - mocker: MockerFixture, -) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - mocker.patch("services.webapp_auth_service.compare_password", return_value=False) - - # Act + Assert - with pytest.raises(AccountPasswordError, match="Invalid email or password"): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - mocker.patch("services.webapp_auth_service.compare_password", return_value=True) - - # Act - result = WebAppAuthService.authenticate("user@example.com", "pwd") - - # Assert - assert result is account - - -def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None: - # Arrange - account = _account(id="a1", email="u@example.com") - mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token") - - # Act - result = WebAppAuthService.login(account) - - # Assert - assert result == "jwt-token" - mock_get_token.assert_called_once_with(account=account) - - -def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None: - # Arrange - mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) - - # Act - result = WebAppAuthService.get_user_through_email("missing@example.com") - - # Assert - assert result is None - - -def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.BANNED) - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act + Assert - with pytest.raises(Unauthorized, match="Account is banned"): - WebAppAuthService.get_user_through_email("user@example.com") - - -def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE) - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act - result = WebAppAuthService.get_user_through_email("user@example.com") - - # Assert - assert result is account - - -def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Email must be provided"): - WebAppAuthService.send_email_code_login_email(account=None, email=None) - - -def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account( - mocker: MockerFixture, -) -> None: - # Arrange - account = _account(email="user@example.com") - mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6]) - mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1") - mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") - - # Act - result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") - - # Assert - assert result == "token-1" - mock_generate_token.assert_called_once() - assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"} - mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456") - - -def test_send_email_code_login_email_should_send_mail_for_email_without_account( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0]) - mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2") - mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") - - # Act - result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans") - - # Assert - assert result == "token-2" - mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000") - - -def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None: - # Arrange - mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"}) - - # Act - result = WebAppAuthService.get_email_code_login_data("token-abc") - - # Assert - assert result == {"code": "123"} - mock_get_data.assert_called_once_with("token-abc", "email_code_login") - - -def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None: - # Arrange - mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token") - - # Act - WebAppAuthService.revoke_email_code_login_token("token-xyz") - - # Assert - mock_revoke.assert_called_once_with("token-xyz", "email_code_login") - - -def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(NotFound, match="Site not found"): - WebAppAuthService.create_end_user("app-code", "user@example.com") - - -def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None: - # Arrange - site = SimpleNamespace(app_id="app-1") - app_query = MagicMock() - app_query.where.return_value.first.return_value = None - mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None] - - # Act + Assert - with pytest.raises(NotFound, match="App not found"): - WebAppAuthService.create_end_user("app-code", "user@example.com") - - -def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None: - # Arrange - site = SimpleNamespace(app_id="app-1") - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model] - - # Act - result = WebAppAuthService.create_end_user("app-code", "user@example.com") - - # Assert - assert result.tenant_id == "tenant-1" - assert result.app_id == "app-1" - assert result.session_id == "user@example.com" - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None: - # Arrange - account = _account(id="a1", email="user@example.com") - mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60) - mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1") - - # Act - token = WebAppAuthService._get_account_jwt_token(account) - - # Assert - assert token == "jwt-1" - payload = mock_issue.call_args.args[0] - assert payload["user_id"] == "a1" - assert payload["session_id"] == "user@example.com" - assert payload["token_source"] == "webapp_login_token" - assert payload["auth_type"] == "internal" - assert payload["exp"] > int(datetime.now(UTC).timestamp()) - - -@pytest.mark.parametrize( - ("access_mode", "expected"), - [ - ("private", True), - ("private_all", True), - ("public", False), - ], -) -def test_is_app_require_permission_check_should_use_access_mode_when_provided( - access_mode: str, - expected: bool, -) -> None: - # Arrange - # Act - result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode) - - # Assert - assert result is expected - - -def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Either app_code or app_id must be provided"): - WebAppAuthService.is_app_require_permission_check() - - -def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None) - - # Act + Assert - with pytest.raises(ValueError, match="App ID could not be determined"): - WebAppAuthService.is_app_require_permission_check(app_code="app-code") - - -def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="private"), - ) - - # Act - result = WebAppAuthService.is_app_require_permission_check(app_code="app-code") - - # Assert - assert result is True - - -def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="public"), - ) - - # Act - result = WebAppAuthService.is_app_require_permission_check(app_id="app-1") - - # Assert - assert result is False - - -@pytest.mark.parametrize( - ("access_mode", "expected"), - [ - ("public", WebAppAuthType.PUBLIC), - ("private", WebAppAuthType.INTERNAL), - ("private_all", WebAppAuthType.INTERNAL), - ("sso_verified", WebAppAuthType.EXTERNAL), - ], -) -def test_get_app_auth_type_should_map_access_modes_correctly( - access_mode: str, - expected: WebAppAuthType, -) -> None: - # Arrange - # Act - result = WebAppAuthService.get_app_auth_type(access_mode=access_mode) - - # Assert - assert result == expected - - -def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="private_all"), - ) - - # Act - result = WebAppAuthService.get_app_auth_type(app_code="app-code") - - # Assert - assert result == WebAppAuthType.INTERNAL - - -def test_get_app_auth_type_should_raise_when_no_input_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"): - WebAppAuthService.get_app_auth_type() - - -def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Could not determine app authentication type"): - WebAppAuthService.get_app_auth_type(access_mode="unknown") diff --git a/api/tests/unit_tests/services/test_workflow_app_service.py b/api/tests/unit_tests/services/test_workflow_app_service.py deleted file mode 100644 index fa76521f2d..0000000000 --- a/api/tests/unit_tests/services/test_workflow_app_service.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from dify_graph.enums import WorkflowExecutionStatus -from models import App, WorkflowAppLog -from models.enums import AppTriggerType, CreatorUserRole -from services.workflow_app_service import LogView, WorkflowAppService - - -@pytest.fixture -def service() -> WorkflowAppService: - # Arrange - return WorkflowAppService() - - -@pytest.fixture -def app_model() -> App: - # Arrange - return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) - - -def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog: - return cast(WorkflowAppLog, SimpleNamespace(**kwargs)) - - -def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None: - # Arrange - log = _workflow_app_log(id="log-1", status="succeeded") - view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) - - # Act - details = view.details - proxied_status = view.status - - # Assert - assert details == {"trigger_metadata": {"type": "plugin"}} - assert proxied_status == "succeeded" - - -def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - log_1 = SimpleNamespace(id="log-1") - log_2 = SimpleNamespace(id="log-2") - session.scalar.return_value = 3 - session.scalars.return_value.all.return_value = [log_1, log_2] - - # Act - result = service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - page=1, - limit=2, - detail=False, - ) - - # Assert - assert result["page"] == 1 - assert result["limit"] == 2 - assert result["total"] == 3 - assert result["has_more"] is True - assert len(result["data"]) == 2 - assert isinstance(result["data"][0], LogView) - assert result["data"][0].details is None - - -def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true( - service: WorkflowAppService, - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - session.scalar.side_effect = [1] - log_1 = SimpleNamespace(id="log-1") - session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')] - mock_handle = mocker.patch.object( - service, - "handle_trigger_metadata", - return_value={"type": "trigger_plugin", "icon": "url"}, - ) - - # Act - result = service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - keyword="run-1", - status=WorkflowExecutionStatus.SUCCEEDED, - created_at_before=None, - created_at_after=None, - page=1, - limit=20, - detail=True, - ) - - # Assert - assert result["total"] == 1 - assert len(result["data"]) == 1 - assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}} - mock_handle.assert_called_once() - - -def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - session.scalar.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Account not found: account@example.com"): - service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - created_by_account="account@example.com", - ) - - -def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0] - session.scalars.return_value.all.return_value = [] - - # Act - result = service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - created_by_account="account@example.com", - ) - - # Assert - assert result["total"] == 0 - assert result["data"] == [] - - -def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - log_account = SimpleNamespace( - id="log-1", - created_by="acc-1", - created_by_role=CreatorUserRole.ACCOUNT, - workflow_run_summary={"run": "1"}, - trigger_metadata='{"type":"trigger-webhook"}', - log_created_at="2026-01-01", - ) - log_end_user = SimpleNamespace( - id="log-2", - created_by="end-1", - created_by_role=CreatorUserRole.END_USER, - workflow_run_summary={"run": "2"}, - trigger_metadata='{"type":"trigger-webhook"}', - log_created_at="2026-01-02", - ) - log_unknown = SimpleNamespace( - id="log-3", - created_by="other", - created_by_role="system", - workflow_run_summary={"run": "3"}, - trigger_metadata='{"type":"trigger-webhook"}', - log_created_at="2026-01-03", - ) - session.scalar.return_value = 3 - session.scalars.side_effect = [ - SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]), - SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]), - SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]), - ] - - # Act - result = service.get_paginate_workflow_archive_logs( - session=session, - app_model=app_model, - page=1, - limit=20, - ) - - # Assert - assert result["total"] == 3 - assert len(result["data"]) == 3 - assert result["data"][0]["created_by_account"].id == "acc-1" - assert result["data"][1]["created_by_end_user"].id == "end-1" - assert result["data"][2]["created_by_account"] is None - assert result["data"][2]["created_by_end_user"] is None - - -def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing( - service: WorkflowAppService, -) -> None: - # Arrange - # Act - result = service.handle_trigger_metadata("tenant-1", None) - - # Assert - assert result == {} - - -def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin( - service: WorkflowAppService, - mocker: MockerFixture, -) -> None: - # Arrange - meta = { - "type": AppTriggerType.TRIGGER_PLUGIN.value, - "icon_filename": "light.png", - "icon_dark_filename": "dark.png", - } - mock_icon = mocker.patch( - "services.workflow_app_service.PluginService.get_plugin_icon_url", - side_effect=["https://cdn/light.png", "https://cdn/dark.png"], - ) - - # Act - result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) - - # Assert - assert result["icon"] == "https://cdn/light.png" - assert result["icon_dark"] == "https://cdn/dark.png" - assert mock_icon.call_count == 2 - - -def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup( - service: WorkflowAppService, - mocker: MockerFixture, -) -> None: - # Arrange - meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} - mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url") - - # Act - result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) - - # Assert - assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value - mock_icon.assert_not_called() - - -@pytest.mark.parametrize( - ("value", "expected"), - [ - (None, None), - ("", None), - ('{"k":"v"}', {"k": "v"}), - ("not-json", None), - ({"raw": True}, {"raw": True}), - ], -) -def test_safe_json_loads_should_handle_various_inputs( - value: object, - expected: object, - service: WorkflowAppService, -) -> None: - # Arrange - # Act - result = service._safe_json_loads(value) - - # Assert - assert result == expected - - -def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None: - # Arrange - # Act - short_result = service._safe_parse_uuid("short") - invalid_result = service._safe_parse_uuid("x" * 40) - - # Assert - assert short_result is None - assert invalid_result is None - - -def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None: - # Arrange - raw_uuid = str(uuid.uuid4()) - - # Act - result = service._safe_parse_uuid(raw_uuid) - - # Assert - assert result is not None - assert str(result) == raw_uuid diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py deleted file mode 100644 index 9616d2f102..0000000000 --- a/api/tests/unit_tests/services/tools/test_tools_transform_service.py +++ /dev/null @@ -1,452 +0,0 @@ -from unittest.mock import Mock - -from core.tools.__base.tool import Tool -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType -from services.tools.tools_transform_service import ToolTransformService - - -class TestToolTransformService: - """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method""" - - def test_convert_tool_with_parameter_override(self): - """Test that runtime parameters correctly override base parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - # Create mock runtime parameters that override base parameters - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" # Different label to verify override - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.author == "test_author" - assert result.name == "test_tool" - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Find the overridden parameter - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" # Should be runtime version - - # Find the non-overridden parameter - original_param = next((p for p in result.parameters if p.name == "param2"), None) - assert original_param is not None - assert original_param.label == "Base Param 2" # Should be base version - - def test_convert_tool_with_additional_runtime_parameters(self): - """Test that additional runtime parameters are added to the final list""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters - one that overrides and one that's new - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "runtime_only" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Only Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Check that both parameters are present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "runtime_only" in param_names - - # Verify the overridden parameter has runtime version - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" - - # Verify the new runtime parameter is included - new_param = next((p for p in result.parameters if p.name == "runtime_only"), None) - assert new_param is not None - assert new_param.label == "Runtime Only Param" - - def test_convert_tool_with_non_form_runtime_parameters(self): - """Test that non-FORM runtime parameters are not added as new parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters with different forms - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "llm_param" - runtime_param2.form = ToolParameter.ToolParameterForm.LLM - runtime_param2.type = "string" - runtime_param2.label = "LLM Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 1 # Only the FORM parameter should be present - - # Check that only the FORM parameter is present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "llm_param" not in param_names - - def test_convert_tool_with_empty_parameters(self): - """Test conversion with empty base and runtime parameters""" - # Create mock tool with no parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_with_none_parameters(self): - """Test conversion when base parameters is None""" - # Create mock tool with None parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = None - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_parameter_order_preserved(self): - """Test that parameter order is preserved correctly""" - # Create mock base parameters in specific order - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - base_param3 = Mock(spec=ToolParameter) - base_param3.name = "param3" - base_param3.form = ToolParameter.ToolParameterForm.FORM - base_param3.type = "string" - base_param3.label = "Base Param 3" - - # Create runtime parameter that overrides middle parameter - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "param2" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Param 2" - - # Create new runtime parameter - runtime_param4 = Mock(spec=ToolParameter) - runtime_param4.name = "param4" - runtime_param4.form = ToolParameter.ToolParameterForm.FORM - runtime_param4.type = "string" - runtime_param4.label = "Runtime Param 4" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2, base_param3] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 4 - - # Check that order is maintained: base parameters first, then new runtime parameters - param_names = [p.name for p in result.parameters] - assert param_names == ["param1", "param2", "param3", "param4"] - - # Verify that param2 was overridden with runtime version - param2 = result.parameters[1] - assert param2.name == "param2" - assert param2.label == "Runtime Param 2" - - -class TestWorkflowProviderToUserProvider: - """Test cases for ToolTransformService.workflow_provider_to_user_provider method""" - - def test_workflow_provider_to_user_provider_with_workflow_app_id(self): - """Test that workflow_provider_to_user_provider correctly sets workflow_app_id.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - workflow_app_id = "app_123" - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1", "label2"], - workflow_app_id=workflow_app_id, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "test_author" - assert result.name == "test_workflow_tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["label1", "label2"] - assert result.is_team_authorization is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] - - def test_workflow_provider_to_user_provider_without_workflow_app_id(self): - """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method without workflow_app_id - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1"], - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == ["label1"] - - def test_workflow_provider_to_user_provider_workflow_app_id_none(self): - """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method with explicit None values - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=None, - workflow_app_id=None, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == [] - - def test_workflow_provider_to_user_provider_preserves_other_fields(self): - """Test that workflow_provider_to_user_provider preserves all other entity fields.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller with various fields - workflow_app_id = "app_456" - provider_id = "provider_456" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "another_author" - mock_controller.entity.identity.name = "another_workflow_tool" - mock_controller.entity.identity.description = I18nObject( - en_US="Another description", zh_Hans="Another description" - ) - mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"} - mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.label = I18nObject( - en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool" - ) - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["automation", "workflow"], - workflow_app_id=workflow_app_id, - ) - - # Verify all fields are preserved correctly - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "another_author" - assert result.name == "another_workflow_tool" - assert result.description.en_US == "Another description" - assert result.description.zh_Hans == "Another description" - assert result.icon == {"type": "emoji", "content": "⚙️"} - assert result.icon_dark == {"type": "emoji", "content": "🔧"} - assert result.label.en_US == "Another Workflow Tool" - assert result.label.zh_Hans == "Another Workflow Tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["automation", "workflow"] - assert result.masked_credentials == {} - assert result.is_team_authorization is True - assert result.allow_delete is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index e180063041..33a5607ef4 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,7 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import Document from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -153,7 +153,7 @@ class VectorServiceTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", index_struct_dict: dict | None = None, @@ -494,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -535,7 +535,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -568,7 +568,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -591,7 +591,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -616,7 +616,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="economy" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -669,7 +669,7 @@ class TestVectorService: store when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -695,7 +695,7 @@ class TestVectorService: index when using economy indexing with keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -731,7 +731,7 @@ class TestVectorService: index when using economy indexing without keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -895,7 +895,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -923,7 +923,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -951,7 +951,7 @@ class TestVectorService: when there are new chunks, updated chunks, and deleted chunks. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") @@ -993,7 +993,7 @@ class TestVectorService: add_texts is called, not delete_by_ids. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1019,7 +1019,7 @@ class TestVectorService: delete_by_ids is called, not add_texts. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1045,7 +1045,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1075,7 +1075,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1099,7 +1099,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index c0a4d2f113..936a10d6c5 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -184,7 +184,7 @@ class TestErrorHandling: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -229,7 +229,7 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -265,7 +265,7 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -321,7 +321,7 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -366,7 +366,7 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -408,7 +408,7 @@ class TestEdgeCases: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -445,7 +445,7 @@ class TestIndexProcessorParameters: - Dataset object with correct attributes is passed """ # Arrange - indexing_technique = "high_quality" + indexing_technique = IndexTechniqueType.HIGH_QUALITY index_struct = '{"type": "paragraph"}' # Act diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 027cd3b1ec..0b189ebae2 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,7 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -209,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id): dataset = Mock(spec=Dataset) dataset.id = dataset_id dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset diff --git a/web/app/components/tools/workflow-tool/__tests__/configure-button.spec.tsx b/web/app/components/tools/workflow-tool/__tests__/configure-button.spec.tsx index 9cd66e37ea..5deed8174d 100644 --- a/web/app/components/tools/workflow-tool/__tests__/configure-button.spec.tsx +++ b/web/app/components/tools/workflow-tool/__tests__/configure-button.spec.tsx @@ -49,9 +49,12 @@ vi.mock('@/service/use-tools', () => ({ // Mock Toast - need to verify notification calls const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (options: { type: string, message: string }) => mockToastNotify(options), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), }, })) diff --git a/web/app/components/tools/workflow-tool/hooks/__tests__/use-configure-button.spec.ts b/web/app/components/tools/workflow-tool/hooks/__tests__/use-configure-button.spec.ts index ad0dd2eff2..ac61872a18 100644 --- a/web/app/components/tools/workflow-tool/hooks/__tests__/use-configure-button.spec.ts +++ b/web/app/components/tools/workflow-tool/hooks/__tests__/use-configure-button.spec.ts @@ -33,9 +33,12 @@ vi.mock('@/service/use-tools', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (options: { type: string, message: string }) => mockToastNotify(options), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), }, })) diff --git a/web/app/components/tools/workflow-tool/hooks/use-configure-button.ts b/web/app/components/tools/workflow-tool/hooks/use-configure-button.ts index 701ae8fd01..142b0c2397 100644 --- a/web/app/components/tools/workflow-tool/hooks/use-configure-button.ts +++ b/web/app/components/tools/workflow-tool/hooks/use-configure-button.ts @@ -3,7 +3,7 @@ import type { InputVar, Variable } from '@/app/components/workflow/types' import type { PublishWorkflowParams } from '@/types/workflow' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useAppContext } from '@/context/app-context' import { useRouter } from '@/next/navigation' import { createWorkflowToolProvider, saveWorkflowToolProvider } from '@/service/tools' @@ -188,14 +188,11 @@ export function useConfigureButton(options: UseConfigureButtonOptions) { invalidateAllWorkflowTools() onRefreshData?.() invalidateDetail(workflowAppId) - Toast.notify({ - type: 'success', - message: t('api.actionSuccess', { ns: 'common' }), - }) + toast.success(t('api.actionSuccess', { ns: 'common' })) setShowModal(false) } catch (e) { - Toast.notify({ type: 'error', message: (e as Error).message }) + toast.error((e as Error).message) } } @@ -209,14 +206,11 @@ export function useConfigureButton(options: UseConfigureButtonOptions) { onRefreshData?.() invalidateAllWorkflowTools() invalidateDetail(workflowAppId) - Toast.notify({ - type: 'success', - message: t('api.actionSuccess', { ns: 'common' }), - }) + toast.success(t('api.actionSuccess', { ns: 'common' })) setShowModal(false) } catch (e) { - Toast.notify({ type: 'error', message: (e as Error).message }) + toast.error((e as Error).message) } } diff --git a/web/app/components/tools/workflow-tool/index.tsx b/web/app/components/tools/workflow-tool/index.tsx index 78375857ea..06aeb1ba79 100644 --- a/web/app/components/tools/workflow-tool/index.tsx +++ b/web/app/components/tools/workflow-tool/index.tsx @@ -12,8 +12,8 @@ import Drawer from '@/app/components/base/drawer-plus' import EmojiPicker from '@/app/components/base/emoji-picker' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import LabelSelector from '@/app/components/tools/labels/selector' import ConfirmModal from '@/app/components/tools/workflow-tool/confirm-modal' import MethodSelector from '@/app/components/tools/workflow-tool/method-selector' @@ -129,10 +129,7 @@ const WorkflowToolAsModal: FC = ({ errorMessage = t('createTool.nameForToolCall', { ns: 'tools' }) + t('createTool.nameForToolCallTip', { ns: 'tools' }) if (errorMessage) { - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) return } diff --git a/web/app/components/workflow/__tests__/candidate-node-main.spec.tsx b/web/app/components/workflow/__tests__/candidate-node-main.spec.tsx new file mode 100644 index 0000000000..61e5410aac --- /dev/null +++ b/web/app/components/workflow/__tests__/candidate-node-main.spec.tsx @@ -0,0 +1,260 @@ +import { render, screen } from '@testing-library/react' +import CandidateNodeMain from '../candidate-node-main' +import { CUSTOM_NODE } from '../constants' +import { CUSTOM_NOTE_NODE } from '../note-node/constants' +import { BlockEnum } from '../types' +import { createNode } from './fixtures' + +const mockUseEventListener = vi.hoisted(() => vi.fn()) +const mockUseStoreApi = vi.hoisted(() => vi.fn()) +const mockUseReactFlow = vi.hoisted(() => vi.fn()) +const mockUseViewport = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseWorkflowStore = vi.hoisted(() => vi.fn()) +const mockUseHooks = vi.hoisted(() => vi.fn()) +const mockCustomNode = vi.hoisted(() => vi.fn()) +const mockCustomNoteNode = vi.hoisted(() => vi.fn()) +const mockGetIterationStartNode = vi.hoisted(() => vi.fn()) +const mockGetLoopStartNode = vi.hoisted(() => vi.fn()) + +vi.mock('ahooks', () => ({ + useEventListener: (...args: unknown[]) => mockUseEventListener(...args), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => mockUseStoreApi(), + useReactFlow: () => mockUseReactFlow(), + useViewport: () => mockUseViewport(), + Position: { + Left: 'left', + Right: 'right', + }, +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { mousePosition: { + pageX: number + pageY: number + elementX: number + elementY: number + } }) => unknown) => mockUseStore(selector), + useWorkflowStore: () => mockUseWorkflowStore(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesInteractions: () => mockUseHooks().useNodesInteractions(), + useNodesSyncDraft: () => mockUseHooks().useNodesSyncDraft(), + useWorkflowHistory: () => mockUseHooks().useWorkflowHistory(), + useAutoGenerateWebhookUrl: () => mockUseHooks().useAutoGenerateWebhookUrl(), + WorkflowHistoryEvent: { + NodeAdd: 'NodeAdd', + NoteAdd: 'NoteAdd', + }, +})) + +vi.mock('@/app/components/workflow/nodes', () => ({ + __esModule: true, + default: (props: { id: string }) => { + mockCustomNode(props) + return
{props.id}
+ }, +})) + +vi.mock('@/app/components/workflow/note-node', () => ({ + __esModule: true, + default: (props: { id: string }) => { + mockCustomNoteNode(props) + return
{props.id}
+ }, +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + getIterationStartNode: (...args: unknown[]) => mockGetIterationStartNode(...args), + getLoopStartNode: (...args: unknown[]) => mockGetLoopStartNode(...args), +})) + +describe('CandidateNodeMain', () => { + const mockSetNodes = vi.fn() + const mockHandleNodeSelect = vi.fn() + const mockSaveStateToHistory = vi.fn() + const mockHandleSyncWorkflowDraft = vi.fn() + const mockAutoGenerateWebhookUrl = vi.fn() + const mockWorkflowStoreSetState = vi.fn() + const createNodesInteractions = () => ({ + handleNodeSelect: mockHandleNodeSelect, + }) + const createWorkflowHistory = () => ({ + saveStateToHistory: mockSaveStateToHistory, + }) + const createNodesSyncDraft = () => ({ + handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, + }) + const createAutoGenerateWebhookUrl = () => mockAutoGenerateWebhookUrl + const eventHandlers: Partial void }) => void>> = {} + let nodes = [createNode({ id: 'existing-node' })] + + beforeEach(() => { + vi.clearAllMocks() + nodes = [createNode({ id: 'existing-node' })] + eventHandlers.click = undefined + eventHandlers.contextmenu = undefined + + mockUseEventListener.mockImplementation((event: 'click' | 'contextmenu', handler: (event: { preventDefault: () => void }) => void) => { + eventHandlers[event] = handler + }) + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => nodes, + setNodes: mockSetNodes, + }), + }) + mockUseReactFlow.mockReturnValue({ + screenToFlowPosition: ({ x, y }: { x: number, y: number }) => ({ x: x + 10, y: y + 20 }), + }) + mockUseViewport.mockReturnValue({ zoom: 1.5 }) + mockUseStore.mockImplementation((selector: (state: { mousePosition: { + pageX: number + pageY: number + elementX: number + elementY: number + } }) => unknown) => selector({ + mousePosition: { + pageX: 100, + pageY: 200, + elementX: 30, + elementY: 40, + }, + })) + mockUseWorkflowStore.mockReturnValue({ + setState: mockWorkflowStoreSetState, + }) + mockUseHooks.mockReturnValue({ + useNodesInteractions: createNodesInteractions, + useWorkflowHistory: createWorkflowHistory, + useNodesSyncDraft: createNodesSyncDraft, + useAutoGenerateWebhookUrl: createAutoGenerateWebhookUrl, + }) + mockHandleSyncWorkflowDraft.mockImplementation((_isSync: boolean, _force: boolean, options?: { onSuccess?: () => void }) => { + options?.onSuccess?.() + }) + mockGetIterationStartNode.mockReturnValue(createNode({ id: 'iteration-start' })) + mockGetLoopStartNode.mockReturnValue(createNode({ id: 'loop-start' })) + }) + + it('should render the candidate node and commit a webhook node on click', () => { + const candidateNode = createNode({ + id: 'candidate-webhook', + type: CUSTOM_NODE, + data: { + type: BlockEnum.TriggerWebhook, + title: 'Webhook Candidate', + _isCandidate: true, + }, + }) + + const { container } = render() + + expect(screen.getByTestId('candidate-custom-node')).toHaveTextContent('candidate-webhook') + expect(container.firstChild).toHaveStyle({ + left: '30px', + top: '40px', + transform: 'scale(1.5)', + }) + + eventHandlers.click?.({ preventDefault: vi.fn() }) + + expect(mockSetNodes).toHaveBeenCalledWith(expect.arrayContaining([ + expect.objectContaining({ id: 'existing-node' }), + expect.objectContaining({ + id: 'candidate-webhook', + position: { x: 110, y: 220 }, + data: expect.objectContaining({ _isCandidate: false }), + }), + ])) + expect(mockSaveStateToHistory).toHaveBeenCalledWith('NodeAdd', { nodeId: 'candidate-webhook' }) + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ candidateNode: undefined }) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledWith(true, true, expect.objectContaining({ + onSuccess: expect.any(Function), + })) + expect(mockAutoGenerateWebhookUrl).toHaveBeenCalledWith('candidate-webhook') + expect(mockHandleNodeSelect).not.toHaveBeenCalled() + }) + + it('should save note candidates as notes and select the inserted note', () => { + const candidateNode = createNode({ + id: 'candidate-note', + type: CUSTOM_NOTE_NODE, + data: { + type: BlockEnum.Code, + title: 'Note Candidate', + _isCandidate: true, + }, + }) + + render() + + expect(screen.getByTestId('candidate-note-node')).toHaveTextContent('candidate-note') + + eventHandlers.click?.({ preventDefault: vi.fn() }) + + expect(mockSaveStateToHistory).toHaveBeenCalledWith('NoteAdd', { nodeId: 'candidate-note' }) + expect(mockHandleNodeSelect).toHaveBeenCalledWith('candidate-note') + }) + + it('should append iteration and loop start helper nodes for control-flow candidates', () => { + const iterationNode = createNode({ + id: 'candidate-iteration', + type: CUSTOM_NODE, + data: { + type: BlockEnum.Iteration, + title: 'Iteration Candidate', + _isCandidate: true, + }, + }) + const loopNode = createNode({ + id: 'candidate-loop', + type: CUSTOM_NODE, + data: { + type: BlockEnum.Loop, + title: 'Loop Candidate', + _isCandidate: true, + }, + }) + + const { rerender } = render() + + eventHandlers.click?.({ preventDefault: vi.fn() }) + expect(mockGetIterationStartNode).toHaveBeenCalledWith('candidate-iteration') + expect(mockSetNodes.mock.calls[0][0]).toEqual(expect.arrayContaining([ + expect.objectContaining({ id: 'candidate-iteration' }), + expect.objectContaining({ id: 'iteration-start' }), + ])) + + rerender() + eventHandlers.click?.({ preventDefault: vi.fn() }) + + expect(mockGetLoopStartNode).toHaveBeenCalledWith('candidate-loop') + expect(mockSetNodes.mock.calls[1][0]).toEqual(expect.arrayContaining([ + expect.objectContaining({ id: 'candidate-loop' }), + expect.objectContaining({ id: 'loop-start' }), + ])) + }) + + it('should clear the candidate node on contextmenu', () => { + const candidateNode = createNode({ + id: 'candidate-context', + type: CUSTOM_NODE, + data: { + type: BlockEnum.Code, + title: 'Context Candidate', + _isCandidate: true, + }, + }) + + render() + + eventHandlers.contextmenu?.({ preventDefault: vi.fn() }) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ candidateNode: undefined }) + }) +}) diff --git a/web/app/components/workflow/__tests__/custom-edge.spec.tsx b/web/app/components/workflow/__tests__/custom-edge.spec.tsx new file mode 100644 index 0000000000..f8ff9a1a0e --- /dev/null +++ b/web/app/components/workflow/__tests__/custom-edge.spec.tsx @@ -0,0 +1,235 @@ +import type { ReactNode } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import { Position } from 'reactflow' +import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types' +import CustomEdge from '../custom-edge' +import { BlockEnum, NodeRunningStatus } from '../types' + +const mockUseAvailableBlocks = vi.hoisted(() => vi.fn()) +const mockUseNodesInteractions = vi.hoisted(() => vi.fn()) +const mockBlockSelector = vi.hoisted(() => vi.fn()) +const mockGradientRender = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', () => ({ + BaseEdge: (props: { + id: string + path: string + style: { + stroke: string + strokeWidth: number + opacity: number + strokeDasharray?: string + } + }) => ( +
+ ), + EdgeLabelRenderer: ({ children }: { children?: ReactNode }) =>
{children}
, + getBezierPath: () => ['M 0 0', 24, 48], + Position: { + Right: 'right', + Left: 'left', + }, +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useAvailableBlocks: (...args: unknown[]) => mockUseAvailableBlocks(...args), + useNodesInteractions: () => mockUseNodesInteractions(), +})) + +vi.mock('@/app/components/workflow/block-selector', () => ({ + __esModule: true, + default: (props: { + open: boolean + onOpenChange: (open: boolean) => void + onSelect: (nodeType: string, pluginDefaultValue?: Record) => void + availableBlocksTypes: string[] + triggerClassName?: () => string + }) => { + mockBlockSelector(props) + return ( + + ) + }, +})) + +vi.mock('@/app/components/workflow/custom-edge-linear-gradient-render', () => ({ + __esModule: true, + default: (props: { + id: string + startColor: string + stopColor: string + }) => { + mockGradientRender(props) + return
{props.id}
+ }, +})) + +describe('CustomEdge', () => { + const mockHandleNodeAdd = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseNodesInteractions.mockReturnValue({ + handleNodeAdd: mockHandleNodeAdd, + }) + mockUseAvailableBlocks.mockImplementation((nodeType: BlockEnum) => { + if (nodeType === BlockEnum.Code) + return { availablePrevBlocks: ['code', 'llm'] } + + return { availableNextBlocks: ['llm', 'tool'] } + }) + }) + + it('should render a gradient edge and insert a node between the source and target', () => { + render( + , + ) + + expect(screen.getByTestId('edge-gradient')).toHaveTextContent('edge-1') + expect(mockGradientRender).toHaveBeenCalledWith(expect.objectContaining({ + id: 'edge-1', + startColor: 'var(--color-workflow-link-line-success-handle)', + stopColor: 'var(--color-workflow-link-line-error-handle)', + })) + expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'url(#edge-1)') + expect(screen.getByTestId('base-edge')).toHaveAttribute('data-opacity', '0.3') + expect(screen.getByTestId('base-edge')).toHaveAttribute('data-dasharray', '8 8') + expect(screen.getByTestId('block-selector')).toHaveTextContent('llm') + expect(screen.getByTestId('block-selector').parentElement).toHaveStyle({ + transform: 'translate(-50%, -50%) translate(24px, 48px)', + opacity: '0.7', + }) + + fireEvent.click(screen.getByTestId('block-selector')) + + expect(mockHandleNodeAdd).toHaveBeenCalledWith( + { + nodeType: 'llm', + pluginDefaultValue: { provider: 'openai' }, + }, + { + prevNodeId: 'source-node', + prevNodeSourceHandle: 'source', + nextNodeId: 'target-node', + nextNodeTargetHandle: 'target', + }, + ) + }) + + it('should prefer the running stroke color when the edge is selected', () => { + render( + , + ) + + expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-handle)') + }) + + it('should use the fail-branch running color while the connected node is hovering', () => { + render( + , + ) + + expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-failure-handle)') + }) + + it('should fall back to the default edge color when no highlight state is active', () => { + render( + , + ) + + expect(screen.getByTestId('base-edge')).toHaveAttribute('data-stroke', 'var(--color-workflow-link-line-normal)') + expect(screen.getByTestId('block-selector')).toHaveAttribute('data-trigger-class', 'hover:scale-150 transition-all') + }) +}) diff --git a/web/app/components/workflow/__tests__/node-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/node-contextmenu.spec.tsx new file mode 100644 index 0000000000..7418b7f313 --- /dev/null +++ b/web/app/components/workflow/__tests__/node-contextmenu.spec.tsx @@ -0,0 +1,114 @@ +import type { Node } from '../types' +import { fireEvent, render, screen } from '@testing-library/react' +import NodeContextmenu from '../node-contextmenu' + +const mockUseClickAway = vi.hoisted(() => vi.fn()) +const mockUseNodes = vi.hoisted(() => vi.fn()) +const mockUsePanelInteractions = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockPanelOperatorPopup = vi.hoisted(() => vi.fn()) + +vi.mock('ahooks', () => ({ + useClickAway: (...args: unknown[]) => mockUseClickAway(...args), +})) + +vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({ + __esModule: true, + default: () => mockUseNodes(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + usePanelInteractions: () => mockUsePanelInteractions(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { nodeMenu?: { nodeId: string, left: number, top: number } }) => unknown) => mockUseStore(selector), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup', () => ({ + __esModule: true, + default: (props: { + id: string + data: Node['data'] + showHelpLink: boolean + onClosePopup: () => void + }) => { + mockPanelOperatorPopup(props) + return ( + + ) + }, +})) + +describe('NodeContextmenu', () => { + const mockHandleNodeContextmenuCancel = vi.fn() + let nodeMenu: { nodeId: string, left: number, top: number } | undefined + let nodes: Node[] + let clickAwayHandler: (() => void) | undefined + + beforeEach(() => { + vi.clearAllMocks() + nodeMenu = undefined + nodes = [{ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + title: 'Node 1', + desc: '', + type: 'code' as never, + }, + } as Node] + clickAwayHandler = undefined + + mockUseClickAway.mockImplementation((handler: () => void) => { + clickAwayHandler = handler + }) + mockUseNodes.mockImplementation(() => nodes) + mockUsePanelInteractions.mockReturnValue({ + handleNodeContextmenuCancel: mockHandleNodeContextmenuCancel, + }) + mockUseStore.mockImplementation((selector: (state: { nodeMenu?: { nodeId: string, left: number, top: number } }) => unknown) => selector({ nodeMenu })) + }) + + it('should stay hidden when the node menu is absent', () => { + render() + + expect(screen.queryByRole('button')).not.toBeInTheDocument() + expect(mockPanelOperatorPopup).not.toHaveBeenCalled() + }) + + it('should stay hidden when the referenced node cannot be found', () => { + nodeMenu = { nodeId: 'missing-node', left: 80, top: 120 } + + render() + + expect(screen.queryByRole('button')).not.toBeInTheDocument() + expect(mockPanelOperatorPopup).not.toHaveBeenCalled() + }) + + it('should render the popup at the stored position and close on popup/click-away actions', () => { + nodeMenu = { nodeId: 'node-1', left: 80, top: 120 } + const { container } = render() + + expect(screen.getByRole('button')).toHaveTextContent('node-1:Node 1') + expect(mockPanelOperatorPopup).toHaveBeenCalledWith(expect.objectContaining({ + id: 'node-1', + data: expect.objectContaining({ title: 'Node 1' }), + showHelpLink: true, + })) + expect(container.firstChild).toHaveStyle({ + left: '80px', + top: '120px', + }) + + fireEvent.click(screen.getByRole('button')) + clickAwayHandler?.() + + expect(mockHandleNodeContextmenuCancel).toHaveBeenCalledTimes(2) + }) +}) diff --git a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx new file mode 100644 index 0000000000..914c1be617 --- /dev/null +++ b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx @@ -0,0 +1,151 @@ +import type { ReactNode } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import PanelContextmenu from '../panel-contextmenu' + +const mockUseClickAway = vi.hoisted(() => vi.fn()) +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseNodesInteractions = vi.hoisted(() => vi.fn()) +const mockUsePanelInteractions = vi.hoisted(() => vi.fn()) +const mockUseWorkflowStartRun = vi.hoisted(() => vi.fn()) +const mockUseOperator = vi.hoisted(() => vi.fn()) +const mockUseDSL = vi.hoisted(() => vi.fn()) + +vi.mock('ahooks', () => ({ + useClickAway: (...args: unknown[]) => mockUseClickAway(...args), +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { + panelMenu?: { left: number, top: number } + clipboardElements: unknown[] + setShowImportDSLModal: (visible: boolean) => void + }) => unknown) => mockUseStore(selector), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesInteractions: () => mockUseNodesInteractions(), + usePanelInteractions: () => mockUsePanelInteractions(), + useWorkflowStartRun: () => mockUseWorkflowStartRun(), + useDSL: () => mockUseDSL(), +})) + +vi.mock('@/app/components/workflow/operator/hooks', () => ({ + useOperator: () => mockUseOperator(), +})) + +vi.mock('@/app/components/workflow/operator/add-block', () => ({ + __esModule: true, + default: ({ renderTrigger }: { renderTrigger: () => ReactNode }) => ( +
{renderTrigger()}
+ ), +})) + +vi.mock('@/app/components/base/divider', () => ({ + __esModule: true, + default: ({ className }: { className?: string }) =>
, +})) + +vi.mock('@/app/components/workflow/shortcuts-name', () => ({ + __esModule: true, + default: ({ keys }: { keys: string[] }) => {keys.join('+')}, +})) + +describe('PanelContextmenu', () => { + const mockHandleNodesPaste = vi.fn() + const mockHandlePaneContextmenuCancel = vi.fn() + const mockHandleStartWorkflowRun = vi.fn() + const mockHandleAddNote = vi.fn() + const mockExportCheck = vi.fn() + const mockSetShowImportDSLModal = vi.fn() + let panelMenu: { left: number, top: number } | undefined + let clipboardElements: unknown[] + let clickAwayHandler: (() => void) | undefined + + beforeEach(() => { + vi.clearAllMocks() + panelMenu = undefined + clipboardElements = [] + clickAwayHandler = undefined + + mockUseClickAway.mockImplementation((handler: () => void) => { + clickAwayHandler = handler + }) + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseStore.mockImplementation((selector: (state: { + panelMenu?: { left: number, top: number } + clipboardElements: unknown[] + setShowImportDSLModal: (visible: boolean) => void + }) => unknown) => selector({ + panelMenu, + clipboardElements, + setShowImportDSLModal: mockSetShowImportDSLModal, + })) + mockUseNodesInteractions.mockReturnValue({ + handleNodesPaste: mockHandleNodesPaste, + }) + mockUsePanelInteractions.mockReturnValue({ + handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel, + }) + mockUseWorkflowStartRun.mockReturnValue({ + handleStartWorkflowRun: mockHandleStartWorkflowRun, + }) + mockUseOperator.mockReturnValue({ + handleAddNote: mockHandleAddNote, + }) + mockUseDSL.mockReturnValue({ + exportCheck: mockExportCheck, + }) + }) + + it('should stay hidden when the panel menu is absent', () => { + render() + + expect(screen.queryByTestId('add-block')).not.toBeInTheDocument() + }) + + it('should keep paste disabled when the clipboard is empty', () => { + panelMenu = { left: 24, top: 48 } + + render() + + fireEvent.click(screen.getByText('common.pasteHere')) + + expect(mockHandleNodesPaste).not.toHaveBeenCalled() + expect(mockHandlePaneContextmenuCancel).not.toHaveBeenCalled() + }) + + it('should render actions, position the menu, and execute each action', () => { + panelMenu = { left: 24, top: 48 } + clipboardElements = [{ id: 'copied-node' }] + const { container } = render() + + expect(screen.getByTestId('add-block')).toHaveTextContent('common.addBlock') + expect(screen.getByTestId('shortcut-alt-r')).toHaveTextContent('alt+r') + expect(screen.getByTestId('shortcut-ctrl-v')).toHaveTextContent('ctrl+v') + expect(container.firstChild).toHaveStyle({ + left: '24px', + top: '48px', + }) + + fireEvent.click(screen.getByText('nodes.note.addNote')) + fireEvent.click(screen.getByText('common.run')) + fireEvent.click(screen.getByText('common.pasteHere')) + fireEvent.click(screen.getByText('export')) + fireEvent.click(screen.getByText('common.importDSL')) + clickAwayHandler?.() + + expect(mockHandleAddNote).toHaveBeenCalledTimes(1) + expect(mockHandleStartWorkflowRun).toHaveBeenCalledTimes(1) + expect(mockHandleNodesPaste).toHaveBeenCalledTimes(1) + expect(mockExportCheck).toHaveBeenCalledTimes(1) + expect(mockSetShowImportDSLModal).toHaveBeenCalledWith(true) + expect(mockHandlePaneContextmenuCancel).toHaveBeenCalledTimes(4) + }) +}) diff --git a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx index a85291128b..82645f2028 100644 --- a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx +++ b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx @@ -1,7 +1,7 @@ import type { EventEmitter } from 'ahooks/lib/useEventEmitter' import type { EventEmitterValue } from '@/context/event-emitter' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { EventEmitterContext } from '@/context/event-emitter' import { DSLImportStatus } from '@/models/app' import UpdateDSLModal from '../update-dsl-modal' @@ -16,10 +16,17 @@ class MockFileReader { } vi.stubGlobal('FileReader', MockFileReader as unknown as typeof FileReader) - -const mockNotify = vi.fn() const mockEmit = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: vi.fn(), + info: vi.fn(), + success: vi.fn(), + warning: vi.fn(), + }, +})) + const mockImportDSL = vi.fn() const mockImportDSLConfirm = vi.fn() vi.mock('@/service/apps', () => ({ @@ -59,6 +66,7 @@ vi.mock('@/app/components/app/create-from-dsl-modal/uploader', () => ({ })) describe('UpdateDSLModal', () => { + const mockToastError = vi.mocked(toast.error) const defaultProps = { onCancel: vi.fn(), onBackup: vi.fn(), @@ -91,11 +99,9 @@ describe('UpdateDSLModal', () => { const eventEmitter = { emit: mockEmit } as unknown as EventEmitter return render( - - - - - , + + + , ) } @@ -152,9 +158,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) @@ -233,9 +237,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) expect(mockImportDSL).not.toHaveBeenCalled() @@ -254,9 +256,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) @@ -274,9 +274,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'workflow.common.overwriteAndImport' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) @@ -305,9 +303,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) @@ -334,9 +330,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) @@ -365,9 +359,7 @@ describe('UpdateDSLModal', () => { fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Confirm' })) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) }) diff --git a/web/app/components/workflow/block-selector/__tests__/tool-picker.spec.tsx b/web/app/components/workflow/block-selector/__tests__/tool-picker.spec.tsx index 47ad2fad02..737481601c 100644 --- a/web/app/components/workflow/block-selector/__tests__/tool-picker.spec.tsx +++ b/web/app/components/workflow/block-selector/__tests__/tool-picker.spec.tsx @@ -114,9 +114,12 @@ vi.mock('@/service/use-tools', () => ({ useInvalidateAllMCPTools: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (payload: unknown) => mockNotify(payload), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), }, })) diff --git a/web/app/components/workflow/block-selector/tool-picker.tsx b/web/app/components/workflow/block-selector/tool-picker.tsx index d9ce065dde..cf48488415 100644 --- a/web/app/components/workflow/block-selector/tool-picker.tsx +++ b/web/app/components/workflow/block-selector/tool-picker.tsx @@ -16,7 +16,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import SearchBox from '@/app/components/plugins/marketplace/search-box' import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal' import AllTools from '@/app/components/workflow/block-selector/all-tools' @@ -137,10 +137,7 @@ const ToolPicker: FC = ({ const doCreateCustomToolCollection = async (data: CustomCollectionBackend) => { await createCustomCollection(data) - Toast.notify({ - type: 'success', - message: t('api.actionSuccess', { ns: 'common' }), - }) + toast.success(t('api.actionSuccess', { ns: 'common' })) hideEditCustomCollectionModal() handleAddedCustomTool() } diff --git a/web/app/components/workflow/header/__tests__/header-layouts.spec.tsx b/web/app/components/workflow/header/__tests__/header-layouts.spec.tsx index dc00d61301..d092e769d6 100644 --- a/web/app/components/workflow/header/__tests__/header-layouts.spec.tsx +++ b/web/app/components/workflow/header/__tests__/header-layouts.spec.tsx @@ -60,9 +60,12 @@ vi.mock('@/service/use-workflow', () => ({ }), })) -vi.mock('../../../base/toast', () => ({ - default: { - notify: (payload: unknown) => mockNotify(payload), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), }, })) diff --git a/web/app/components/workflow/header/__tests__/run-mode.spec.tsx b/web/app/components/workflow/header/__tests__/run-mode.spec.tsx index cb5214544a..74dc529a62 100644 --- a/web/app/components/workflow/header/__tests__/run-mode.spec.tsx +++ b/web/app/components/workflow/header/__tests__/run-mode.spec.tsx @@ -46,10 +46,13 @@ vi.mock('../../hooks/use-dynamic-test-run-options', () => ({ useDynamicTestRunOptions: () => mockDynamicOptions, })) -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ - notify: mockNotify, - }), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), + }, })) vi.mock('@/app/components/base/amplitude', () => ({ diff --git a/web/app/components/workflow/header/header-in-restoring.tsx b/web/app/components/workflow/header/header-in-restoring.tsx index 2c5b4b9f08..d32e2c7fb9 100644 --- a/web/app/components/workflow/header/header-in-restoring.tsx +++ b/web/app/components/workflow/header/header-in-restoring.tsx @@ -4,11 +4,11 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' +import { toast } from '@/app/components/base/ui/toast' import useTheme from '@/hooks/use-theme' import { useInvalidAllLastRun, useRestoreWorkflow } from '@/service/use-workflow' import { getFlowPrefix } from '@/service/utils' import { cn } from '@/utils/classnames' -import Toast from '../../base/toast' import { useWorkflowRefreshDraft, useWorkflowRun, @@ -65,18 +65,12 @@ const HeaderInRestoring = ({ workflowStore.setState({ isRestoring: false }) workflowStore.setState({ backupDraft: undefined }) handleRefreshWorkflowDraft() - Toast.notify({ - type: 'success', - message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }), - }) + toast.success(t('versionHistory.action.restoreSuccess', { ns: 'workflow' })) deleteAllInspectVars() invalidAllLastRun() } catch { - Toast.notify({ - type: 'error', - message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }), - }) + toast.error(t('versionHistory.action.restoreFailure', { ns: 'workflow' })) } finally { onRestoreSettled?.() diff --git a/web/app/components/workflow/header/run-mode.tsx b/web/app/components/workflow/header/run-mode.tsx index 86f998e0b7..8f802fcec5 100644 --- a/web/app/components/workflow/header/run-mode.tsx +++ b/web/app/components/workflow/header/run-mode.tsx @@ -5,7 +5,7 @@ import { useCallback, useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' -import { useToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { useWorkflowRun, useWorkflowRunValidation, useWorkflowStartRun } from '@/app/components/workflow/hooks' import ShortcutsName from '@/app/components/workflow/shortcuts-name' import { useStore } from '@/app/components/workflow/store' @@ -41,7 +41,6 @@ const RunMode = ({ const dynamicOptions = useDynamicTestRunOptions() const testRunMenuRef = useRef(null) - const { notify } = useToastContext() useEffect(() => { // @ts-expect-error - Dynamic property for backward compatibility with keyboard shortcuts @@ -66,7 +65,7 @@ const RunMode = ({ isValid = false }) if (!isValid) { - notify({ type: 'error', message: t('panel.checklistTip', { ns: 'workflow' }) }) + toast.error(t('panel.checklistTip', { ns: 'workflow' })) return } @@ -98,7 +97,7 @@ const RunMode = ({ // Placeholder for trigger-specific execution logic for schedule, webhook, plugin types console.log('TODO: Handle trigger execution for type:', option.type, 'nodeId:', option.nodeId) } - }, [warningNodes, notify, t, handleWorkflowStartRunInWorkflow, handleWorkflowTriggerScheduleRunInWorkflow, handleWorkflowTriggerWebhookRunInWorkflow, handleWorkflowTriggerPluginRunInWorkflow, handleWorkflowRunAllTriggersInWorkflow]) + }, [warningNodes, t, handleWorkflowStartRunInWorkflow, handleWorkflowTriggerScheduleRunInWorkflow, handleWorkflowTriggerWebhookRunInWorkflow, handleWorkflowTriggerPluginRunInWorkflow, handleWorkflowRunAllTriggersInWorkflow]) const { eventEmitter } = useEventEmitterContextContext() eventEmitter?.useSubscription((v: any) => { diff --git a/web/app/components/workflow/help-line/__tests__/index.spec.tsx b/web/app/components/workflow/help-line/__tests__/index.spec.tsx new file mode 100644 index 0000000000..f58c9c5d02 --- /dev/null +++ b/web/app/components/workflow/help-line/__tests__/index.spec.tsx @@ -0,0 +1,61 @@ +import { render } from '@testing-library/react' +import HelpLine from '../index' + +const mockUseViewport = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', () => ({ + useViewport: () => mockUseViewport(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { + helpLineHorizontal?: { top: number, left: number, width: number } + helpLineVertical?: { top: number, left: number, height: number } + }) => unknown) => mockUseStore(selector), +})) + +describe('HelpLine', () => { + let helpLineHorizontal: { top: number, left: number, width: number } | undefined + let helpLineVertical: { top: number, left: number, height: number } | undefined + + beforeEach(() => { + vi.clearAllMocks() + helpLineHorizontal = undefined + helpLineVertical = undefined + + mockUseViewport.mockReturnValue({ x: 10, y: 20, zoom: 2 }) + mockUseStore.mockImplementation((selector: (state: { + helpLineHorizontal?: { top: number, left: number, width: number } + helpLineVertical?: { top: number, left: number, height: number } + }) => unknown) => selector({ + helpLineHorizontal, + helpLineVertical, + })) + }) + + it('should render nothing when both help lines are absent', () => { + const { container } = render() + + expect(container).toBeEmptyDOMElement() + }) + + it('should render the horizontal and vertical guide lines using viewport offsets and zoom', () => { + helpLineHorizontal = { top: 30, left: 40, width: 50 } + helpLineVertical = { top: 60, left: 70, height: 80 } + + const { container } = render() + const [horizontal, vertical] = Array.from(container.querySelectorAll('div')) + + expect(horizontal).toHaveStyle({ + top: '80px', + left: '90px', + width: '100px', + }) + expect(vertical).toHaveStyle({ + top: '140px', + left: '150px', + height: '160px', + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts b/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts index a11fe2c981..891007ff0e 100644 --- a/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts +++ b/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts @@ -89,8 +89,13 @@ vi.mock('../index', () => ({ useNodesMetaData: () => ({ nodes: [], nodesMap: mockNodesMap }), })) -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: vi.fn() }), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, })) vi.mock('@/context/i18n', () => ({ diff --git a/web/app/components/workflow/hooks/__tests__/use-config-vision.spec.ts b/web/app/components/workflow/hooks/__tests__/use-config-vision.spec.ts new file mode 100644 index 0000000000..5811f14a60 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-config-vision.spec.ts @@ -0,0 +1,171 @@ +import type { ModelConfig, VisionSetting } from '@/app/components/workflow/types' +import { act, renderHook } from '@testing-library/react' +import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { Resolution } from '@/types/app' +import useConfigVision from '../use-config-vision' + +const mockUseTextGenerationCurrentProviderAndModelAndModelList = vi.hoisted(() => vi.fn()) +const mockUseIsChatMode = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useTextGenerationCurrentProviderAndModelAndModelList: (...args: unknown[]) => + mockUseTextGenerationCurrentProviderAndModelAndModelList(...args), +})) + +vi.mock('../use-workflow', () => ({ + useIsChatMode: () => mockUseIsChatMode(), +})) + +const createModel = (overrides: Partial = {}): ModelConfig => ({ + provider: 'openai', + name: 'gpt-4o', + mode: 'chat', + completion_params: [], + ...overrides, +}) + +const createVisionPayload = (overrides: Partial<{ enabled: boolean, configs?: VisionSetting }> = {}) => ({ + enabled: false, + ...overrides, +}) + +describe('useConfigVision', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseIsChatMode.mockReturnValue(false) + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: [], + }, + }) + }) + + it('should expose vision capability and enable default chat configs for vision models', () => { + const onChange = vi.fn() + mockUseIsChatMode.mockReturnValue(true) + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: [ModelFeatureEnum.vision], + }, + }) + + const { result } = renderHook(() => useConfigVision(createModel(), { + payload: createVisionPayload(), + onChange, + })) + + expect(result.current.isVisionModel).toBe(true) + + act(() => { + result.current.handleVisionResolutionEnabledChange(true) + }) + + expect(onChange).toHaveBeenCalledWith({ + enabled: true, + configs: { + detail: Resolution.high, + variable_selector: ['sys', 'files'], + }, + }) + }) + + it('should clear configs when disabling vision resolution', () => { + const onChange = vi.fn() + + const { result } = renderHook(() => useConfigVision(createModel(), { + payload: createVisionPayload({ + enabled: true, + configs: { + detail: Resolution.low, + variable_selector: ['node', 'files'], + }, + }), + onChange, + })) + + act(() => { + result.current.handleVisionResolutionEnabledChange(false) + }) + + expect(onChange).toHaveBeenCalledWith({ + enabled: false, + }) + }) + + it('should update the resolution config payload directly', () => { + const onChange = vi.fn() + const config: VisionSetting = { + detail: Resolution.low, + variable_selector: ['upstream', 'images'], + } + + const { result } = renderHook(() => useConfigVision(createModel(), { + payload: createVisionPayload({ enabled: true }), + onChange, + })) + + act(() => { + result.current.handleVisionResolutionChange(config) + }) + + expect(onChange).toHaveBeenCalledWith({ + enabled: true, + configs: config, + }) + }) + + it('should disable vision settings when the selected model is no longer a vision model', () => { + const onChange = vi.fn() + + const { result } = renderHook(() => useConfigVision(createModel(), { + payload: createVisionPayload({ + enabled: true, + configs: { + detail: Resolution.high, + variable_selector: ['sys', 'files'], + }, + }), + onChange, + })) + + act(() => { + result.current.handleModelChanged() + }) + + expect(onChange).toHaveBeenCalledWith({ + enabled: false, + }) + }) + + it('should reset enabled vision configs when the model changes but still supports vision', () => { + const onChange = vi.fn() + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: [ModelFeatureEnum.vision], + }, + }) + + const { result } = renderHook(() => useConfigVision(createModel(), { + payload: createVisionPayload({ + enabled: true, + configs: { + detail: Resolution.low, + variable_selector: ['old', 'files'], + }, + }), + onChange, + })) + + act(() => { + result.current.handleModelChanged() + }) + + expect(onChange).toHaveBeenCalledWith({ + enabled: true, + configs: { + detail: Resolution.high, + variable_selector: [], + }, + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-dynamic-test-run-options.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-dynamic-test-run-options.spec.tsx new file mode 100644 index 0000000000..d66e3ebe4a --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-dynamic-test-run-options.spec.tsx @@ -0,0 +1,146 @@ +import { renderHook } from '@testing-library/react' +import { BlockEnum } from '../../types' +import { useDynamicTestRunOptions } from '../use-dynamic-test-run-options' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseNodes = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseAllTriggerPlugins = vi.hoisted(() => vi.fn()) +const mockGetWorkflowEntryNode = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({ + __esModule: true, + default: () => mockUseNodes(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { + buildInTools: unknown[] + customTools: unknown[] + workflowTools: unknown[] + mcpTools: unknown[] + }) => unknown) => mockUseStore(selector), +})) + +vi.mock('@/service/use-triggers', () => ({ + useAllTriggerPlugins: () => mockUseAllTriggerPlugins(), +})) + +vi.mock('@/app/components/workflow/utils/workflow-entry', () => ({ + getWorkflowEntryNode: (...args: unknown[]) => mockGetWorkflowEntryNode(...args), +})) + +describe('useDynamicTestRunOptions', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseStore.mockImplementation((selector: (state: { + buildInTools: unknown[] + customTools: unknown[] + workflowTools: unknown[] + mcpTools: unknown[] + }) => unknown) => selector({ + buildInTools: [], + customTools: [], + workflowTools: [], + mcpTools: [], + })) + mockUseAllTriggerPlugins.mockReturnValue({ + data: [{ + name: 'plugin-provider', + icon: '/plugin-icon.png', + }], + }) + }) + + it('should build user input, trigger options, and a run-all option from workflow nodes', () => { + mockUseNodes.mockReturnValue([ + { + id: 'start-1', + data: { type: BlockEnum.Start, title: 'User Input' }, + }, + { + id: 'schedule-1', + data: { type: BlockEnum.TriggerSchedule, title: 'Daily Schedule' }, + }, + { + id: 'webhook-1', + data: { type: BlockEnum.TriggerWebhook, title: 'Webhook Trigger' }, + }, + { + id: 'plugin-1', + data: { + type: BlockEnum.TriggerPlugin, + title: '', + plugin_name: 'Plugin Trigger', + provider_id: 'plugin-provider', + }, + }, + ]) + + const { result } = renderHook(() => useDynamicTestRunOptions()) + + expect(result.current.userInput).toEqual(expect.objectContaining({ + id: 'start-1', + type: 'user_input', + name: 'User Input', + nodeId: 'start-1', + enabled: true, + })) + expect(result.current.triggers).toEqual([ + expect.objectContaining({ + id: 'schedule-1', + type: 'schedule', + name: 'Daily Schedule', + nodeId: 'schedule-1', + }), + expect.objectContaining({ + id: 'webhook-1', + type: 'webhook', + name: 'Webhook Trigger', + nodeId: 'webhook-1', + }), + expect.objectContaining({ + id: 'plugin-1', + type: 'plugin', + name: 'Plugin Trigger', + nodeId: 'plugin-1', + }), + ]) + expect(result.current.runAll).toEqual(expect.objectContaining({ + id: 'run-all', + type: 'all', + relatedNodeIds: ['schedule-1', 'webhook-1', 'plugin-1'], + })) + }) + + it('should fall back to the workflow entry node and omit run-all when only one trigger exists', () => { + mockUseNodes.mockReturnValue([ + { + id: 'webhook-1', + data: { type: BlockEnum.TriggerWebhook, title: 'Webhook Trigger' }, + }, + ]) + mockGetWorkflowEntryNode.mockReturnValue({ + id: 'fallback-start', + data: { type: BlockEnum.Start, title: '' }, + }) + + const { result } = renderHook(() => useDynamicTestRunOptions()) + + expect(result.current.userInput).toEqual(expect.objectContaining({ + id: 'fallback-start', + type: 'user_input', + name: 'blocks.start', + nodeId: 'fallback-start', + })) + expect(result.current.triggers).toHaveLength(1) + expect(result.current.runAll).toBeUndefined() + }) +}) diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 029892c4d1..99536653ce 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -27,7 +27,7 @@ import { import { useTranslation } from 'react-i18next' import { useEdges, useStoreApi } from 'reactflow' import { useStore as useAppStore } from '@/app/components/app/store' -import { useToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import useNodes from '@/app/components/workflow/store/workflow/use-nodes' @@ -325,7 +325,6 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { export const useChecklistBeforePublish = () => { const { t } = useTranslation() const language = useGetLanguage() - const { notify } = useToastContext() const queryClient = useQueryClient() const store = useStoreApi() const { nodesMap: nodesExtraData } = useNodesMetaData() @@ -390,7 +389,7 @@ export const useChecklistBeforePublish = () => { const { validNodes, maxDepth } = getValidTreeNodes(filteredNodes, edges) if (maxDepth > MAX_TREE_DEPTH) { - notify({ type: 'error', message: t('common.maxTreeDepth', { ns: 'workflow', depth: MAX_TREE_DEPTH }) }) + toast.error(t('common.maxTreeDepth', { ns: 'workflow', depth: MAX_TREE_DEPTH })) return false } @@ -488,7 +487,7 @@ export const useChecklistBeforePublish = () => { isModelProviderInstalled: isLLMModelProviderInstalled(modelProvider, installedPluginIds), }) if (modelIssue === LLMModelIssueCode.providerPluginUnavailable) { - notify({ type: 'error', message: `[${node.data.title}] ${t('errorMsg.configureModel', { ns: 'workflow' })}` }) + toast.error(`[${node.data.title}] ${t('errorMsg.configureModel', { ns: 'workflow' })}`) return false } } @@ -497,7 +496,7 @@ export const useChecklistBeforePublish = () => { const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) if (errorMessage) { - notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` }) + toast.error(`[${node.data.title}] ${errorMessage}`) return false } @@ -510,12 +509,12 @@ export const useChecklistBeforePublish = () => { if (usedNode) { const usedVar = usedNode.vars.find(v => v.variable === variable?.[1]) if (!usedVar) { - notify({ type: 'error', message: `[${node.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}` }) + toast.error(`[${node.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}`) return false } } else { - notify({ type: 'error', message: `[${node.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}` }) + toast.error(`[${node.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}`) return false } } @@ -526,7 +525,7 @@ export const useChecklistBeforePublish = () => { const isUnconnected = !validNodes.some(n => n.id === node.id) if (isUnconnected && !canSkipConnectionCheck) { - notify({ type: 'error', message: `[${node.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}` }) + toast.error(`[${node.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}`) return false } } @@ -534,7 +533,7 @@ export const useChecklistBeforePublish = () => { if (shouldCheckStartNode) { const startNodesFiltered = nodes.filter(node => START_NODE_TYPES.includes(node.data.type as BlockEnum)) if (startNodesFiltered.length === 0) { - notify({ type: 'error', message: t('common.needStartNode', { ns: 'workflow' }) }) + toast.error(t('common.needStartNode', { ns: 'workflow' })) return false } } @@ -545,13 +544,13 @@ export const useChecklistBeforePublish = () => { const type = isRequiredNodesType[i] if (!filteredNodes.some(node => node.data.type === type)) { - notify({ type: 'error', message: t('common.needAdd', { ns: 'workflow', node: t(`blocks.${type}` as I18nKeysWithPrefix<'workflow', 'blocks.'>, { ns: 'workflow' }) }) }) + toast.error(t('common.needAdd', { ns: 'workflow', node: t(`blocks.${type}` as I18nKeysWithPrefix<'workflow', 'blocks.'>, { ns: 'workflow' }) })) return false } } return true - }, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, queryClient, strategyProviders, modelProviders]) + }, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, queryClient, strategyProviders, modelProviders]) return { handleCheckBeforePublish, @@ -563,15 +562,14 @@ export const useWorkflowRunValidation = () => { const nodes = useNodes() const edges = useEdges() const needWarningNodes = useChecklist(nodes, edges) - const { notify } = useToastContext() const validateBeforeRun = useCallback(() => { if (needWarningNodes.length > 0) { - notify({ type: 'error', message: t('panel.checklistTip', { ns: 'workflow' }) }) + toast.error(t('panel.checklistTip', { ns: 'workflow' })) return false } return true - }, [needWarningNodes, notify, t]) + }, [needWarningNodes, t]) return { validateBeforeRun, diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index cd35d2310f..8de86edecb 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1822,6 +1822,8 @@ export const useNodesInteractions = () => { else { // single node paste const selectedNode = nodes.find(node => node.selected) + let pastedToNestedBlock = false + if (selectedNode) { const commonNestedDisallowPasteNodes = [ // end node only can be placed outermost layer @@ -1849,10 +1851,24 @@ export const useNodesInteractions = () => { } // set position base on parent node newNode.position = getNestedNodePosition(newNode, selectedNode) + // update parent children array like native add parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type }) + + pastedToNestedBlock = true } } + + // Clear loop/iteration metadata when pasting outside nested blocks (fixes #29835) + // This ensures nodes copied from inside Loop/Iteration are properly independent + // when pasted outside + if (!pastedToNestedBlock) { + newNode.data.isInLoop = false + newNode.data.loop_id = undefined + newNode.data.isInIteration = false + newNode.data.iteration_id = undefined + newNode.parentId = undefined + } } idMapping[nodeToPaste.id] = newNode.id diff --git a/web/app/components/workflow/nodes/_base/components/__tests__/file-support.spec.tsx b/web/app/components/workflow/nodes/_base/components/__tests__/file-support.spec.tsx index ffe1e80bb0..b58b045f92 100644 --- a/web/app/components/workflow/nodes/_base/components/__tests__/file-support.spec.tsx +++ b/web/app/components/workflow/nodes/_base/components/__tests__/file-support.spec.tsx @@ -19,11 +19,13 @@ vi.mock('@/app/components/base/file-uploader/hooks', () => ({ useFileSizeLimit: vi.fn(), })) -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ - notify: vi.fn(), - close: vi.fn(), - }), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, })) const createPayload = (overrides: Partial = {}): UploadFileSetting => ({ diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/__tests__/index.spec.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/__tests__/index.spec.tsx index 6ed8210721..a8837f6392 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/__tests__/index.spec.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/__tests__/index.spec.tsx @@ -1,10 +1,16 @@ import type { Props as FormProps } from '../form' import type { BeforeRunFormProps } from '../index' import { fireEvent, render, screen } from '@testing-library/react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { BlockEnum, InputVarType } from '@/app/components/workflow/types' import BeforeRunForm from '../index' +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: vi.fn(), + }, +})) + vi.mock('../form', () => ({ default: ({ values }: { values: Record }) =>
{Object.keys(values).join(',')}
, })) @@ -29,6 +35,8 @@ vi.mock('@/app/components/workflow/nodes/human-input/components/single-run-form' })) describe('BeforeRunForm', () => { + const mockToastError = vi.mocked(toast.error) + const createForm = (form: Partial): FormProps => ({ inputs: [], values: {}, @@ -66,8 +74,6 @@ describe('BeforeRunForm', () => { }) it('should show an error toast when required fields are missing', () => { - const notifySpy = vi.spyOn(Toast, 'notify').mockImplementation(vi.fn()) - render( { fireEvent.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })) - expect(notifySpy).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) it('should generate the human input form instead of running immediately', () => { @@ -199,8 +203,6 @@ describe('BeforeRunForm', () => { }) it('should show an error toast when json input is invalid', () => { - const notifySpy = vi.spyOn(Toast, 'notify').mockImplementation(vi.fn()) - render( { fireEvent.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })) - expect(notifySpy).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(mockToastError).toHaveBeenCalled() }) }) diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx index 0e414f70a5..be29fbbc22 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx @@ -9,7 +9,7 @@ import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Split from '@/app/components/workflow/nodes/_base/components/split' import SingleRunForm from '@/app/components/workflow/nodes/human-input/components/single-run-form' import { BlockEnum } from '@/app/components/workflow/types' @@ -71,19 +71,13 @@ const BeforeRunForm: FC = ({ const handleRunOrGenerateForm = () => { const errMsg = getFormErrorMessage(forms, existVarValuesInForms, t) if (errMsg) { - Toast.notify({ - message: errMsg, - type: 'error', - }) + toast.error(errMsg) return } const { submitData, parseErrorJsonField } = buildSubmitData(forms) if (parseErrorJsonField) { - Toast.notify({ - message: t('errorMsg.invalidJson', { ns: 'workflow', field: parseErrorJsonField }), - type: 'error', - }) + toast.error(t('errorMsg.invalidJson', { ns: 'workflow', field: parseErrorJsonField })) return } diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/__tests__/index.spec.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/__tests__/index.spec.tsx new file mode 100644 index 0000000000..91d346abc9 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/__tests__/index.spec.tsx @@ -0,0 +1,235 @@ +import { act, render, screen } from '@testing-library/react' +import { NodeRunningStatus } from '@/app/components/workflow/types' +import LastRun from '../index' + +const mockUseHooksStore = vi.hoisted(() => vi.fn()) +const mockUseLastRun = vi.hoisted(() => vi.fn()) +const mockResultPanel = vi.hoisted(() => vi.fn()) + +vi.mock('@remixicon/react', () => ({ + RiLoader2Line: () =>
, +})) + +vi.mock('@/app/components/workflow/hooks-store', () => ({ + useHooksStore: (selector: (state: { + configsMap?: { flowType?: string, flowId?: string } + }) => unknown) => mockUseHooksStore(selector), +})) + +vi.mock('@/service/use-workflow', () => ({ + useLastRun: (...args: unknown[]) => mockUseLastRun(...args), +})) + +vi.mock('@/app/components/workflow/run/result-panel', () => ({ + __esModule: true, + default: (props: Record) => { + mockResultPanel(props) + return
{String(props.status)}
+ }, +})) + +vi.mock('../no-data', () => ({ + __esModule: true, + default: ({ onSingleRun }: { onSingleRun: () => void }) => ( + + ), +})) + +describe('LastRun', () => { + const updateNodeRunningStatus = vi.fn() + const onSingleRunClicked = vi.fn() + let visibilityState = 'visible' + + beforeEach(() => { + vi.clearAllMocks() + mockUseHooksStore.mockImplementation((selector: (state: { + configsMap?: { flowType?: string, flowId?: string } + }) => unknown) => selector({ + configsMap: { + flowType: 'appFlow', + flowId: 'flow-1', + }, + })) + mockUseLastRun.mockReturnValue({ + data: undefined, + isFetching: false, + error: undefined, + }) + visibilityState = 'visible' + Object.defineProperty(document, 'visibilityState', { + configurable: true, + get: () => visibilityState, + }) + }) + + it('should show a loader while fetching the last run before any single run starts', () => { + mockUseLastRun.mockReturnValue({ + data: undefined, + isFetching: true, + error: undefined, + }) + + render( + , + ) + + expect(screen.getByTestId('loading-icon')).toBeInTheDocument() + expect(screen.queryByTestId('result-panel')).not.toBeInTheDocument() + }) + + it('should show a running result panel while a single run is still executing', () => { + render( + , + ) + + expect(screen.getByTestId('result-panel')).toHaveTextContent('running') + expect(mockResultPanel).toHaveBeenCalledWith(expect.objectContaining({ + status: 'running', + showSteps: false, + })) + }) + + it('should render the no-data state for 404 last-run responses and forward single-run clicks', () => { + mockUseLastRun.mockReturnValue({ + data: undefined, + isFetching: false, + error: { status: 404 }, + }) + + render( + , + ) + + act(() => { + screen.getByText('no-data').click() + }) + + expect(onSingleRunClicked).toHaveBeenCalledTimes(1) + }) + + it('should render resolved result data and let paused state override the final status', () => { + mockUseLastRun.mockReturnValue({ + data: { + status: NodeRunningStatus.Succeeded, + execution_metadata: { total_tokens: 9 }, + created_by_account: { created_by: 'Alice' }, + }, + isFetching: false, + error: undefined, + }) + + render( + , + ) + + expect(screen.getByTestId('result-panel')).toHaveTextContent(NodeRunningStatus.Stopped) + expect(mockResultPanel).toHaveBeenCalledWith(expect.objectContaining({ + status: NodeRunningStatus.Stopped, + total_tokens: 9, + created_by: 'Alice', + showSteps: false, + })) + }) + + it('should respect stopped and listening one-step statuses', () => { + mockUseLastRun.mockReturnValue({ + data: { + status: NodeRunningStatus.Succeeded, + }, + isFetching: false, + error: undefined, + }) + + const { rerender } = render( + , + ) + + expect(screen.getByTestId('result-panel')).toHaveTextContent(NodeRunningStatus.Stopped) + + rerender( + , + ) + + expect(screen.getByTestId('result-panel')).toHaveTextContent(NodeRunningStatus.Listening) + }) + + it('should react to page visibility changes while keeping the current result rendered', () => { + mockUseLastRun.mockReturnValue({ + data: { + status: NodeRunningStatus.Succeeded, + }, + isFetching: false, + error: undefined, + }) + + render( + , + ) + + act(() => { + visibilityState = 'hidden' + document.dispatchEvent(new Event('visibilitychange')) + visibilityState = 'visible' + document.dispatchEvent(new Event('visibilitychange')) + }) + + expect(screen.getByTestId('result-panel')).toHaveTextContent(NodeRunningStatus.Succeeded) + }) +}) diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts index db7833af2b..d58c787bd8 100644 --- a/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts @@ -3,7 +3,7 @@ import type { Params as OneStepRunParams } from '@/app/components/workflow/nodes // import import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types' import { useCallback, useEffect, useState } from 'react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useNodesSyncDraft, } from '@/app/components/workflow/hooks' @@ -163,7 +163,7 @@ const useLastRun = ({ return false const message = warningForNode.errorMessages[0] || 'This node has unresolved checklist issues' - Toast.notify({ type: 'error', message }) + toast.error(message) return true }, [warningNodes, id]) diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index 06843eacef..c634fd92f4 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -12,7 +12,7 @@ import { } from 'reactflow' import { trackEvent } from '@/app/components/base/amplitude' import { getInputVars as doGetInputVars } from '@/app/components/base/prompt-editor/constants' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useIsChatMode, useNodeDataUpdate, @@ -410,14 +410,14 @@ const useOneStepRun = ({ }) if (!response) { - const message = 'Schedule trigger run failed' - Toast.notify({ type: 'error', message }) + const message = t('common.scheduleTriggerRunFailed', { ns: 'workflow' }) + toast.error(message) throw new Error(message) } if (response?.status === 'error') { - const message = response?.message || 'Schedule trigger run failed' - Toast.notify({ type: 'error', message }) + const message = response?.message || t('common.scheduleTriggerRunFailed', { ns: 'workflow' }) + toast.error(message) throw new Error(message) } @@ -442,10 +442,10 @@ const useOneStepRun = ({ _singleRunningStatus: NodeRunningStatus.Failed, }, }) - Toast.notify({ type: 'error', message: 'Schedule trigger run failed' }) + toast.error(t('common.scheduleTriggerRunFailed', { ns: 'workflow' })) throw error } - }, [flowId, id, handleNodeDataUpdate, data]) + }, [flowId, id, handleNodeDataUpdate, data, t]) const runWebhookSingleRun = useCallback(async (): Promise => { const urlPath = `/apps/${flowId}/workflows/draft/nodes/${id}/trigger/run` @@ -467,8 +467,8 @@ const useOneStepRun = ({ return null if (!response) { - const message = response?.message || 'Webhook debug failed' - Toast.notify({ type: 'error', message }) + const message = response?.message || t('common.webhookDebugFailed', { ns: 'workflow' }) + toast.error(message) cancelWebhookSingleRun() throw new Error(message) } @@ -495,8 +495,8 @@ const useOneStepRun = ({ } if (response?.status === 'error') { - const message = response.message || 'Webhook debug failed' - Toast.notify({ type: 'error', message }) + const message = response.message || t('common.webhookDebugFailed', { ns: 'workflow' }) + toast.error(message) cancelWebhookSingleRun() throw new Error(message) } @@ -519,7 +519,7 @@ const useOneStepRun = ({ if (controller.signal.aborted) return null - Toast.notify({ type: 'error', message: 'Webhook debug request failed' }) + toast.error(t('common.webhookDebugRequestFailed', { ns: 'workflow' })) cancelWebhookSingleRun() if (error instanceof Error) throw error @@ -531,7 +531,7 @@ const useOneStepRun = ({ } return null - }, [flowId, id, data, handleNodeDataUpdate, cancelWebhookSingleRun]) + }, [flowId, id, data, handleNodeDataUpdate, cancelWebhookSingleRun, t]) const runPluginSingleRun = useCallback(async (): Promise => { const urlPath = `/apps/${flowId}/workflows/draft/nodes/${id}/trigger/run` @@ -566,14 +566,14 @@ const useOneStepRun = ({ if (controller.signal.aborted) return null - Toast.notify({ type: 'error', message: requestError.message }) + toast.error(requestError.message) cancelPluginSingleRun() throw requestError } if (!response) { const message = 'Plugin debug failed' - Toast.notify({ type: 'error', message }) + toast.error(message) cancelPluginSingleRun() throw new Error(message) } @@ -600,7 +600,7 @@ const useOneStepRun = ({ if (response?.status === 'error') { const message = response.message || 'Plugin debug failed' - Toast.notify({ type: 'error', message }) + toast.error(message) cancelPluginSingleRun() throw new Error(message) } @@ -633,10 +633,8 @@ const useOneStepRun = ({ _isSingleRun: false, }, }) - Toast.notify({ - type: 'error', - message: res.errorMessage || '', - }) + if (res.errorMessage) + toast.error(res.errorMessage) } return res } diff --git a/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..6d009ba60b --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-config.spec.ts @@ -0,0 +1,139 @@ +import type { DataSourceNodeType } from '../../types' +import { renderHook } from '@testing-library/react' +import { VarType as VarKindType } from '../../types' +import { useConfig } from '../use-config' + +const mockUseStoreApi = vi.hoisted(() => vi.fn()) +const mockUseNodeDataUpdate = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', () => ({ + useStoreApi: () => mockUseStoreApi(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodeDataUpdate: () => mockUseNodeDataUpdate(), +})) + +const createNode = (overrides: Partial = {}): { id: string, data: DataSourceNodeType } => ({ + id: 'data-source-node', + data: { + title: 'Datasource', + desc: '', + type: 'data-source', + plugin_id: 'plugin-1', + provider_type: 'local_file', + provider_name: 'provider', + datasource_name: 'source-a', + datasource_label: 'Source A', + datasource_parameters: {}, + datasource_configurations: {}, + _dataSourceStartToAdd: true, + ...overrides, + } as DataSourceNodeType, +}) + +describe('data-source/hooks/use-config', () => { + const mockHandleNodeDataUpdateWithSyncDraft = vi.fn() + let currentNode = createNode() + + beforeEach(() => { + vi.clearAllMocks() + currentNode = createNode() + + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [currentNode], + }), + }) + mockUseNodeDataUpdate.mockReturnValue({ + handleNodeDataUpdateWithSyncDraft: mockHandleNodeDataUpdateWithSyncDraft, + }) + }) + + it('should clear the local-file auto-add flag on mount and update datasource payloads', () => { + const { result } = renderHook(() => useConfig('data-source-node')) + + expect(mockHandleNodeDataUpdateWithSyncDraft).toHaveBeenCalledWith({ + id: 'data-source-node', + data: expect.objectContaining({ + _dataSourceStartToAdd: false, + }), + }) + + mockHandleNodeDataUpdateWithSyncDraft.mockClear() + result.current.handleFileExtensionsChange(['pdf', 'csv']) + result.current.handleParametersChange({ + dataset: { + type: VarKindType.constant, + value: 'docs', + }, + }) + + expect(mockHandleNodeDataUpdateWithSyncDraft).toHaveBeenNthCalledWith(1, { + id: 'data-source-node', + data: expect.objectContaining({ + fileExtensions: ['pdf', 'csv'], + }), + }) + expect(mockHandleNodeDataUpdateWithSyncDraft).toHaveBeenNthCalledWith(2, { + id: 'data-source-node', + data: expect.objectContaining({ + datasource_parameters: { + dataset: { + type: VarKindType.constant, + value: 'docs', + }, + }, + }), + }) + }) + + it('should derive output schema metadata and detect object outputs', () => { + const dataSourceList = [{ + plugin_id: 'plugin-1', + tools: [{ + name: 'source-a', + output_schema: { + properties: { + items: { + type: 'array', + items: { type: 'string' }, + description: 'List of items', + }, + metadata: { + type: 'object', + description: 'Object field', + }, + count: { + type: 'number', + description: 'Total count', + }, + }, + }, + }], + }] + + const { result } = renderHook(() => useConfig('data-source-node', dataSourceList)) + + expect(result.current.outputSchema).toEqual([ + { + name: 'items', + type: 'Array[String]', + description: 'List of items', + }, + { + name: 'metadata', + value: { + type: 'object', + description: 'Object field', + }, + }, + { + name: 'count', + type: 'Number', + description: 'Total count', + }, + ]) + expect(result.current.hasObjectOutput).toBe(true) + }) +}) diff --git a/web/app/components/workflow/nodes/http/components/__tests__/curl-panel.spec.tsx b/web/app/components/workflow/nodes/http/components/__tests__/curl-panel.spec.tsx index 1d11b9b882..f42e98f605 100644 --- a/web/app/components/workflow/nodes/http/components/__tests__/curl-panel.spec.tsx +++ b/web/app/components/workflow/nodes/http/components/__tests__/curl-panel.spec.tsx @@ -1,15 +1,16 @@ import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' +import { toast } from '@/app/components/base/ui/toast' import { BodyPayloadValueType, BodyType } from '../../types' import CurlPanel from '../curl-panel' import * as curlParser from '../curl-parser' const { mockHandleNodeSelect, - mockNotify, + mockToastError, } = vi.hoisted(() => ({ mockHandleNodeSelect: vi.fn(), - mockNotify: vi.fn(), + mockToastError: vi.fn(), })) vi.mock('@/app/components/workflow/hooks', () => ({ @@ -18,9 +19,9 @@ vi.mock('@/app/components/workflow/hooks', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: mockNotify, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, }, })) @@ -131,9 +132,7 @@ describe('curl-panel', () => { await user.type(screen.getByRole('textbox'), 'invalid') await user.click(screen.getByRole('button', { name: 'common.operation.save' })) - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - })) + expect(vi.mocked(toast.error)).toHaveBeenCalledWith(expect.stringContaining('Invalid cURL command')) }) it('should keep the panel open when parsing returns no node and no error', async () => { @@ -159,7 +158,7 @@ describe('curl-panel', () => { expect(onHide).not.toHaveBeenCalled() expect(handleCurlImport).not.toHaveBeenCalled() expect(mockHandleNodeSelect).not.toHaveBeenCalled() - expect(mockNotify).not.toHaveBeenCalled() + expect(vi.mocked(toast.error)).not.toHaveBeenCalled() }) }) }) diff --git a/web/app/components/workflow/nodes/http/components/curl-panel.tsx b/web/app/components/workflow/nodes/http/components/curl-panel.tsx index 7b6a26cc29..b08d3f0a7f 100644 --- a/web/app/components/workflow/nodes/http/components/curl-panel.tsx +++ b/web/app/components/workflow/nodes/http/components/curl-panel.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useNodesInteractions } from '@/app/components/workflow/hooks' import { parseCurl } from './curl-parser' @@ -26,10 +26,7 @@ const CurlPanel: FC = ({ nodeId, isShow, onHide, handleCurlImport }) => { const handleSave = useCallback(() => { const { node, error } = parseCurl(inputString) if (error) { - Toast.notify({ - type: 'error', - message: error, - }) + toast.error(error) return } if (!node) diff --git a/web/app/components/workflow/nodes/human-input/components/__tests__/button-style-dropdown.spec.tsx b/web/app/components/workflow/nodes/human-input/components/__tests__/button-style-dropdown.spec.tsx new file mode 100644 index 0000000000..b7b0229424 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/__tests__/button-style-dropdown.spec.tsx @@ -0,0 +1,149 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import * as React from 'react' +import { UserActionButtonType } from '../../types' +import ButtonStyleDropdown from '../button-style-dropdown' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockButton = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/base/button', () => ({ + __esModule: true, + default: (props: { + variant?: string + children?: React.ReactNode + className?: string + }) => { + mockButton(props) + return
{props.children}
+ }, +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => { + const OpenContext = React.createContext(false) + + return { + PortalToFollowElem: ({ + open, + children, + }: { + open: boolean + children?: React.ReactNode + }) => ( + +
{children}
+
+ ), + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children?: React.ReactNode + onClick?: () => void + }) => ( + + ), + PortalToFollowElemContent: ({ + children, + }: { + children?: React.ReactNode + }) => { + const open = React.use(OpenContext) + return open ?
{children}
: null + }, + } +}) + +describe('ButtonStyleDropdown', () => { + const onChange = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + }) + + it('should map the current style to the trigger button and update the selected style', () => { + render( + , + ) + + expect(mockButton).toHaveBeenCalledWith(expect.objectContaining({ + variant: 'ghost', + })) + expect(screen.getByTestId('portal')).toHaveAttribute('data-open', 'false') + + fireEvent.click(screen.getByTestId('portal-trigger')) + expect(screen.getByTestId('portal')).toHaveAttribute('data-open', 'true') + expect(screen.getByText('nodes.humanInput.userActions.chooseStyle')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('button-primary').parentElement as HTMLElement) + fireEvent.click(screen.getByTestId('button-secondary').parentElement as HTMLElement) + fireEvent.click(screen.getByTestId('button-secondary-accent').parentElement as HTMLElement) + fireEvent.click(screen.getAllByTestId('button-ghost')[1].parentElement as HTMLElement) + + expect(onChange).toHaveBeenNthCalledWith(1, UserActionButtonType.Primary) + expect(onChange).toHaveBeenNthCalledWith(2, UserActionButtonType.Default) + expect(onChange).toHaveBeenNthCalledWith(3, UserActionButtonType.Accent) + expect(onChange).toHaveBeenNthCalledWith(4, UserActionButtonType.Ghost) + }) + + it('should keep the dropdown closed in readonly mode', () => { + render( + , + ) + + expect(mockButton).toHaveBeenCalledWith(expect.objectContaining({ + variant: 'secondary', + })) + + fireEvent.click(screen.getByTestId('portal-trigger')) + + expect(screen.getByTestId('portal')).toHaveAttribute('data-open', 'false') + expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument() + expect(onChange).not.toHaveBeenCalled() + }) + + it('should map the accent style to the secondary-accent trigger button', () => { + render( + , + ) + + expect(mockButton).toHaveBeenCalledWith(expect.objectContaining({ + variant: 'secondary-accent', + })) + }) + + it('should map the primary style to the primary trigger button', () => { + render( + , + ) + + expect(mockButton).toHaveBeenCalledWith(expect.objectContaining({ + variant: 'primary', + })) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/__tests__/form-content-preview.spec.tsx b/web/app/components/workflow/nodes/human-input/components/__tests__/form-content-preview.spec.tsx new file mode 100644 index 0000000000..e98a74e6b4 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/__tests__/form-content-preview.spec.tsx @@ -0,0 +1,135 @@ +import type { ReactNode } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import { UserActionButtonType } from '../../types' +import FormContentPreview from '../form-content-preview' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseNodes = vi.hoisted(() => vi.fn()) +const mockGetButtonStyle = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { panelWidth: number }) => unknown) => mockUseStore(selector), +})) + +vi.mock('@/app/components/workflow/store/workflow/use-nodes', () => ({ + __esModule: true, + default: () => mockUseNodes(), +})) + +vi.mock('@/app/components/base/action-button', () => ({ + __esModule: true, + default: ({ children, onClick }: { children?: ReactNode, onClick?: () => void }) => ( + + ), +})) + +vi.mock('@/app/components/base/badge', () => ({ + __esModule: true, + default: ({ children }: { children?: ReactNode }) =>
{children}
, +})) + +vi.mock('@/app/components/base/button', () => ({ + __esModule: true, + default: ({ children, variant }: { children?: ReactNode, variant?: string }) => ( + + ), +})) + +vi.mock('@/app/components/base/chat/chat/answer/human-input-content/utils', () => ({ + getButtonStyle: (...args: unknown[]) => mockGetButtonStyle(...args), +})) + +vi.mock('@/app/components/base/markdown', () => ({ + Markdown: ({ customComponents }: { + customComponents: { + variable: (props: { node: { properties: { dataPath: string } } }) => ReactNode + section: (props: { node: { properties: { dataName: string } } }) => ReactNode + } + }) => ( +
+ {customComponents.variable({ node: { properties: { dataPath: '#node-1.answer#' } } })} + {customComponents.section({ node: { properties: { dataName: 'field_1' } } })} + {customComponents.section({ node: { properties: { dataName: 'missing_field' } } })} +
+ ), +})) + +vi.mock('../variable-in-markdown', () => ({ + rehypeNotes: vi.fn(), + rehypeVariable: vi.fn(), + Variable: ({ path }: { path: string }) =>
{path}
, + Note: ({ defaultInput, nodeName }: { + defaultInput: { selector: string[] } + nodeName: (nodeId: string) => string + }) =>
{nodeName(defaultInput.selector[0])}
, +})) + +describe('FormContentPreview', () => { + const onClose = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseStore.mockImplementation((selector: (state: { panelWidth: number }) => unknown) => selector({ panelWidth: 320 })) + mockUseNodes.mockReturnValue([{ + id: 'node-1', + data: { title: 'Classifier' }, + }]) + mockGetButtonStyle.mockImplementation((style: UserActionButtonType) => style.toLowerCase()) + }) + + it('should render preview content with resolved node names, note fallbacks, and action buttons', () => { + const { container } = render( + , + ) + + expect(container.firstChild).toHaveStyle({ right: '328px' }) + expect(screen.getByTestId('badge')).toHaveTextContent('nodes.humanInput.formContent.preview') + expect(screen.getByTestId('variable-path')).toHaveTextContent('#Classifier.answer#') + expect(screen.getByTestId('note')).toHaveTextContent('Classifier') + expect(screen.getByText(/Can't find note:/)).toHaveTextContent('missing_field') + expect(screen.getByTestId('action-primary')).toHaveTextContent('Approve') + expect(screen.getByText('nodes.humanInput.editor.previewTip')).toBeInTheDocument() + }) + + it('should close the preview when the close action is clicked', () => { + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'close-preview' })) + + expect(onClose).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/__tests__/form-content.spec.tsx b/web/app/components/workflow/nodes/human-input/components/__tests__/form-content.spec.tsx new file mode 100644 index 0000000000..218da57fbb --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/__tests__/form-content.spec.tsx @@ -0,0 +1,258 @@ +import type { ReactNode } from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import FormContent from '../form-content' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseWorkflowVariableType = vi.hoisted(() => vi.fn()) +const mockIsMac = vi.hoisted(() => vi.fn()) +const mockPromptEditor = vi.hoisted(() => vi.fn()) +const mockAddInputField = vi.hoisted(() => vi.fn()) +const mockOnInsert = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), + Trans: ({ + i18nKey, + components, + }: { + i18nKey: string + components?: Record + }) => ( +
+
{i18nKey}
+ {components?.CtrlKey} + {components?.Key} +
+ ), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowVariableType: () => mockUseWorkflowVariableType(), +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + isMac: () => mockIsMac(), +})) + +vi.mock('@/app/components/base/prompt-editor', () => ({ + __esModule: true, + default: (props: { + onChange: (value: string) => void + onFocus: () => void + onBlur: () => void + shortcutPopups?: Array<{ + Popup: (props: { onClose: () => void, onInsert: typeof mockOnInsert }) => ReactNode + }> + editable?: boolean + hitlInputBlock: { + workflowNodesMap: Record + } + }) => { + mockPromptEditor(props) + const popup = props.shortcutPopups?.[0] + return ( +
+ + + + {popup && popup.Popup({ onClose: vi.fn(), onInsert: mockOnInsert })} +
+ ) + }, +})) + +vi.mock('../add-input-field', () => ({ + __esModule: true, + default: (props: { + onSave: (payload: { + type: string + output_variable_name: string + default: { + type: string + selector: string[] + value: string + } + }) => void + onCancel: () => void + }) => { + mockAddInputField(props) + return ( +
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/base/prompt-editor/plugins/hitl-input-block', () => ({ + INSERT_HITL_INPUT_BLOCK_COMMAND: 'INSERT_HITL_INPUT_BLOCK_COMMAND', +})) + +describe('FormContent', () => { + const onChange = vi.fn() + const onFormInputsChange = vi.fn() + const onFormInputItemRename = vi.fn() + const onFormInputItemRemove = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseWorkflowVariableType.mockReturnValue(() => 'string') + mockIsMac.mockReturnValue(false) + }) + + it('should build workflow node maps, show the hotkey tip on focus, and defer form-input sync until value changes', async () => { + const { rerender } = render( + , + ) + + expect(mockPromptEditor).toHaveBeenCalledWith(expect.objectContaining({ + editable: true, + hitlInputBlock: expect.objectContaining({ + workflowNodesMap: expect.objectContaining({ + 'node-1': expect.objectContaining({ title: 'Start' }), + 'node-2': expect.objectContaining({ title: 'Classifier' }), + 'sys': expect.objectContaining({ title: 'blocks.start' }), + }), + }), + })) + + fireEvent.click(screen.getByText('focus-editor')) + expect(screen.getByText('nodes.humanInput.formContent.hotkeyTip')).toBeInTheDocument() + + fireEvent.click(screen.getByText('save-input')) + expect(mockOnInsert).toHaveBeenCalledWith('INSERT_HITL_INPUT_BLOCK_COMMAND', expect.objectContaining({ + variableName: 'approval', + nodeId: 'node-2', + formInputs: [expect.objectContaining({ output_variable_name: 'approval' })], + onFormInputsChange, + onFormInputItemRename, + onFormInputItemRemove, + })) + expect(onFormInputsChange).not.toHaveBeenCalled() + + rerender( + , + ) + + await waitFor(() => { + expect(onFormInputsChange).toHaveBeenCalledWith([ + expect.objectContaining({ output_variable_name: 'approval' }), + ]) + }) + }) + + it('should disable editing helpers in readonly mode', () => { + const { container } = render( + , + ) + + expect(mockPromptEditor).toHaveBeenCalledWith(expect.objectContaining({ + editable: false, + shortcutPopups: [], + })) + expect(screen.queryByText('save-input')).not.toBeInTheDocument() + expect(container.firstChild).toHaveClass('pointer-events-none') + }) + + it('should render the mac hotkey hint when focused on macOS', () => { + mockIsMac.mockReturnValue(true) + + render( + , + ) + + fireEvent.click(screen.getByText('focus-editor')) + + expect(screen.getByText('⌘')).toBeInTheDocument() + expect(screen.getByText('/')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/__tests__/timeout.spec.tsx b/web/app/components/workflow/nodes/human-input/components/__tests__/timeout.spec.tsx new file mode 100644 index 0000000000..0424fac72d --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/__tests__/timeout.spec.tsx @@ -0,0 +1,77 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import TimeoutInput from '../timeout' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/base/input', () => ({ + __esModule: true, + default: (props: { + value: number + disabled?: boolean + onChange: (event: { target: { value: string } }) => void + }) => ( + props.onChange({ target: { value: e.target.value } })} + /> + ), +})) + +describe('TimeoutInput', () => { + const onChange = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + }) + + it('should update the numeric timeout value and switch units', () => { + render( + , + ) + + fireEvent.change(screen.getByTestId('timeout-input'), { target: { value: '12' } }) + fireEvent.click(screen.getByText('nodes.humanInput.timeout.hours')) + + expect(onChange).toHaveBeenNthCalledWith(1, { timeout: 12, unit: 'day' }) + expect(onChange).toHaveBeenNthCalledWith(2, { timeout: 3, unit: 'hour' }) + }) + + it('should fall back to 1 on invalid input and stay read-only when disabled', () => { + const { rerender } = render( + , + ) + + fireEvent.change(screen.getByTestId('timeout-input'), { target: { value: 'abc' } }) + expect(onChange).toHaveBeenCalledWith({ timeout: 1, unit: 'hour' }) + + rerender( + , + ) + + fireEvent.click(screen.getByText('nodes.humanInput.timeout.days')) + expect(onChange).toHaveBeenCalledTimes(1) + expect(screen.getByTestId('timeout-input')).toBeDisabled() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/__tests__/user-action.spec.tsx b/web/app/components/workflow/nodes/human-input/components/__tests__/user-action.spec.tsx new file mode 100644 index 0000000000..af488af817 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/__tests__/user-action.spec.tsx @@ -0,0 +1,146 @@ +import type { ReactNode } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import { UserActionButtonType } from '../../types' +import UserActionItem from '../user-action' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockNotify = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/base/input', () => ({ + __esModule: true, + default: (props: { + value: string + placeholder?: string + disabled?: boolean + onChange: (event: { target: { value: string } }) => void + }) => ( + props.onChange({ target: { value: e.target.value } })} + /> + ), +})) + +vi.mock('@/app/components/base/button', () => ({ + __esModule: true, + default: (props: { + children?: ReactNode + onClick?: () => void + }) => ( + + ), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + __esModule: true, + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), + }, +})) + +vi.mock('../button-style-dropdown', () => ({ + __esModule: true, + default: (props: { + onChange: (type: UserActionButtonType) => void + }) => ( + + ), +})) + +describe('UserActionItem', () => { + const onChange = vi.fn() + const onDelete = vi.fn() + const action = { + id: 'approve', + title: 'Approve', + button_style: UserActionButtonType.Primary, + } + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + }) + + it('should sanitize ids, enforce length limits, and update the button text', () => { + render( + , + ) + + fireEvent.change(screen.getByTestId('nodes.humanInput.userActions.actionNamePlaceholder'), { target: { value: 'Approve action' } }) + fireEvent.change(screen.getByTestId('nodes.humanInput.userActions.actionNamePlaceholder'), { target: { value: '1invalid' } }) + fireEvent.change(screen.getByTestId('nodes.humanInput.userActions.actionNamePlaceholder'), { target: { value: 'averyveryveryverylongidentifier' } }) + fireEvent.change(screen.getByTestId('nodes.humanInput.userActions.buttonTextPlaceholder'), { target: { value: 'A very very very long button title' } }) + + expect(onChange).toHaveBeenNthCalledWith(1, expect.objectContaining({ + id: 'Approve_action', + })) + expect(onChange).toHaveBeenNthCalledWith(2, expect.objectContaining({ + id: 'averyveryveryverylon', + })) + expect(onChange).toHaveBeenNthCalledWith(3, expect.objectContaining({ + title: 'A very very very lon', + })) + expect(mockNotify).toHaveBeenNthCalledWith(1, expect.objectContaining({ + type: 'error', + message: 'nodes.humanInput.userActions.actionIdFormatTip', + })) + expect(mockNotify).toHaveBeenNthCalledWith(2, expect.objectContaining({ + type: 'error', + message: 'nodes.humanInput.userActions.actionIdTooLong', + })) + expect(mockNotify).toHaveBeenNthCalledWith(3, expect.objectContaining({ + type: 'error', + message: 'nodes.humanInput.userActions.buttonTextTooLong', + })) + }) + + it('should support clearing ids, updating button style, deleting, and readonly mode', () => { + const { rerender } = render( + , + ) + + fireEvent.change(screen.getByTestId('nodes.humanInput.userActions.actionNamePlaceholder'), { target: { value: ' ' } }) + fireEvent.click(screen.getByText('change-style')) + fireEvent.click(screen.getAllByRole('button')[1]) + + expect(onChange).toHaveBeenNthCalledWith(1, expect.objectContaining({ id: '' })) + expect(onChange).toHaveBeenNthCalledWith(2, expect.objectContaining({ button_style: UserActionButtonType.Ghost })) + expect(onDelete).toHaveBeenCalledWith('approve') + + rerender( + , + ) + + expect(screen.getByTestId('nodes.humanInput.userActions.actionNamePlaceholder')).toBeDisabled() + expect(screen.getByTestId('nodes.humanInput.userActions.buttonTextPlaceholder')).toBeDisabled() + expect(screen.getAllByRole('button')).toHaveLength(1) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/index.spec.tsx b/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/index.spec.tsx new file mode 100644 index 0000000000..03bc0f2b79 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/index.spec.tsx @@ -0,0 +1,150 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { DeliveryMethodType } from '../../../types' +import DeliveryMethodForm from '../index' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseNodesSyncDraft = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + __esModule: true, + default: ({ popupContent }: { popupContent: string }) =>
{popupContent}
, +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesSyncDraft: () => mockUseNodesSyncDraft(), +})) + +vi.mock('../method-selector', () => ({ + __esModule: true, + default: (props: { + onAdd: (method: { id: string, type: DeliveryMethodType, enabled: boolean }) => void + onShowUpgradeTip: () => void + }) => ( +
+ + +
+ ), +})) + +vi.mock('../method-item', () => ({ + __esModule: true, + default: (props: { + method: { type: DeliveryMethodType, enabled: boolean } + onChange: (method: { type: DeliveryMethodType, enabled: boolean }) => void + onDelete: (type: DeliveryMethodType) => void + }) => ( +
+ + +
+ ), +})) + +vi.mock('../upgrade-modal', () => ({ + __esModule: true, + default: ({ onClose }: { onClose: () => void }) => ( + + ), +})) + +describe('DeliveryMethodForm', () => { + const onChange = vi.fn() + const mockHandleSyncWorkflowDraft = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseNodesSyncDraft.mockReturnValue({ + handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, + }) + }) + + it('should render the empty state and add methods through the selector', () => { + render( + , + ) + + expect(screen.getByText('nodes.humanInput.deliveryMethod.emptyTip')).toBeInTheDocument() + fireEvent.click(screen.getByText('add-method')) + + expect(onChange).toHaveBeenCalledWith([ + { + id: 'email-1', + type: DeliveryMethodType.Email, + enabled: false, + }, + ]) + expect(mockHandleSyncWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should change and delete methods, syncing the draft after updates', () => { + render( + , + ) + + fireEvent.click(screen.getByText('change-method')) + fireEvent.click(screen.getByText('delete-method')) + + expect(onChange).toHaveBeenNthCalledWith(1, [{ + id: 'email-1', + type: DeliveryMethodType.Email, + enabled: true, + }]) + expect(onChange).toHaveBeenNthCalledWith(2, []) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledWith(true, true) + }) + + it('should open and close the upgrade modal', () => { + render( + , + ) + + fireEvent.click(screen.getByText('show-upgrade')) + expect(screen.getByText('upgrade-modal')).toBeInTheDocument() + + fireEvent.click(screen.getByText('upgrade-modal')) + expect(screen.queryByText('upgrade-modal')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/delivery-method/email-configure-modal.tsx b/web/app/components/workflow/nodes/human-input/components/delivery-method/email-configure-modal.tsx index fa5cbfd3a2..0aa8b1f640 100644 --- a/web/app/components/workflow/nodes/human-input/components/delivery-method/email-configure-modal.tsx +++ b/web/app/components/workflow/nodes/human-input/components/delivery-method/email-configure-modal.tsx @@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import MailBodyInput from './mail-body-input' import Recipient from './recipient' @@ -45,31 +45,22 @@ const EmailConfigureModal = ({ const checkValidConfig = useCallback(() => { if (!subject.trim()) { - Toast.notify({ - type: 'error', - message: 'subject is required', - }) + toast.error(t(`${i18nPrefix}.deliveryMethod.emailConfigure.subjectRequired`, { ns: 'workflow' })) return false } if (!body.trim()) { - Toast.notify({ - type: 'error', - message: 'body is required', - }) + toast.error(t(`${i18nPrefix}.deliveryMethod.emailConfigure.bodyRequired`, { ns: 'workflow' })) return false } if (!/\{\{#url#\}\}/.test(body.trim())) { - Toast.notify({ - type: 'error', - message: `body must contain one ${t('promptEditor.requestURL.item.title', { ns: 'common' })}`, - }) + toast.error(t(`${i18nPrefix}.deliveryMethod.emailConfigure.bodyMustContainRequestURL`, { + ns: 'workflow', + field: t('promptEditor.requestURL.item.title', { ns: 'common' }), + })) return false } if (!recipients || (recipients.items.length === 0 && !recipients.whole_workspace)) { - Toast.notify({ - type: 'error', - message: 'recipients is required', - }) + toast.error(t(`${i18nPrefix}.deliveryMethod.emailConfigure.recipientsRequired`, { ns: 'workflow' })) return false } return true diff --git a/web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/__tests__/index.spec.tsx b/web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/__tests__/index.spec.tsx new file mode 100644 index 0000000000..96cfc10c23 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/__tests__/index.spec.tsx @@ -0,0 +1,156 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import Recipient from '../index' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseAppContext = vi.hoisted(() => vi.fn()) +const mockUseMembers = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => mockUseAppContext(), +})) + +vi.mock('@/service/use-common', () => ({ + useMembers: () => mockUseMembers(), +})) + +vi.mock('@/app/components/base/switch', () => ({ + __esModule: true, + default: (props: { + value: boolean + onChange: (value: boolean) => void + }) => ( + + ), +})) + +vi.mock('../member-selector', () => ({ + __esModule: true, + default: ({ onSelect }: { onSelect: (id: string) => void }) => ( + + ), +})) + +vi.mock('../email-input', () => ({ + __esModule: true, + default: (props: { + onAdd: (email: string) => void + onSelect: (id: string) => void + onDelete: (recipient: { type: 'member' | 'external', user_id?: string, email?: string }) => void + }) => ( +
+ + + + +
+ ), +})) + +describe('Recipient', () => { + const onChange = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string, options?: { workspaceName?: string }) => options?.workspaceName ?? key, + }) + mockUseAppContext.mockReturnValue({ + userProfile: { email: 'owner@example.com' }, + currentWorkspace: { name: 'Dify\'s Lab' }, + }) + mockUseMembers.mockReturnValue({ + data: { + accounts: [ + { id: 'member-1', email: 'member-1@example.com', name: 'Member One' }, + { id: 'member-2', email: 'member-2@example.com', name: 'Member Two' }, + { id: 'member-3', email: 'member-3@example.com', name: 'Member Three' }, + ], + }, + }) + }) + + it('should render workspace details and update recipients through member/email actions', () => { + render( + , + ) + + expect(screen.getByText('D')).toBeInTheDocument() + expect(screen.getByText('Dify’s Lab')).toBeInTheDocument() + + fireEvent.click(screen.getByText('add-member')) + fireEvent.click(screen.getByText('add-email')) + fireEvent.click(screen.getByText('add-email-member')) + fireEvent.click(screen.getByText('delete-member')) + fireEvent.click(screen.getByText('delete-external')) + fireEvent.click(screen.getByText('toggle-workspace')) + + expect(onChange).toHaveBeenNthCalledWith(1, { + whole_workspace: false, + items: [ + { type: 'member', user_id: 'member-1' }, + { type: 'external', email: 'external@example.com' }, + { type: 'member', user_id: 'member-2' }, + ], + }) + expect(onChange).toHaveBeenNthCalledWith(2, { + whole_workspace: false, + items: [ + { type: 'member', user_id: 'member-1' }, + { type: 'external', email: 'external@example.com' }, + { type: 'external', email: 'new@example.com' }, + ], + }) + expect(onChange).toHaveBeenNthCalledWith(3, { + whole_workspace: false, + items: [ + { type: 'member', user_id: 'member-1' }, + { type: 'external', email: 'external@example.com' }, + { type: 'member', user_id: 'member-3' }, + ], + }) + expect(onChange).toHaveBeenNthCalledWith(4, { + whole_workspace: false, + items: [ + { type: 'external', email: 'external@example.com' }, + ], + }) + expect(onChange).toHaveBeenNthCalledWith(5, { + whole_workspace: false, + items: [ + { type: 'member', user_id: 'member-1' }, + ], + }) + expect(onChange).toHaveBeenNthCalledWith(6, { + whole_workspace: true, + items: [ + { type: 'member', user_id: 'member-1' }, + { type: 'external', email: 'external@example.com' }, + ], + }) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/user-action.tsx b/web/app/components/workflow/nodes/human-input/components/user-action.tsx index d124a80051..ca0398eb7f 100644 --- a/web/app/components/workflow/nodes/human-input/components/user-action.tsx +++ b/web/app/components/workflow/nodes/human-input/components/user-action.tsx @@ -7,7 +7,7 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import ButtonStyleDropdown from './button-style-dropdown' const i18nPrefix = 'nodes.humanInput' @@ -47,14 +47,14 @@ const UserActionItem: FC = ({ .join('') if (sanitized !== withUnderscores) { - Toast.notify({ type: 'error', message: t(`${i18nPrefix}.userActions.actionIdFormatTip`, { ns: 'workflow' }) }) + toast.error(t(`${i18nPrefix}.userActions.actionIdFormatTip`, { ns: 'workflow' })) return } // Limit to 20 characters if (sanitized.length > ACTION_ID_MAX_LENGTH) { sanitized = sanitized.slice(0, ACTION_ID_MAX_LENGTH) - Toast.notify({ type: 'error', message: t(`${i18nPrefix}.userActions.actionIdTooLong`, { ns: 'workflow', maxLength: ACTION_ID_MAX_LENGTH }) }) + toast.error(t(`${i18nPrefix}.userActions.actionIdTooLong`, { ns: 'workflow', maxLength: ACTION_ID_MAX_LENGTH })) } if (sanitized) @@ -65,7 +65,7 @@ const UserActionItem: FC = ({ let value = e.target.value if (value.length > BUTTON_TEXT_MAX_LENGTH) { value = value.slice(0, BUTTON_TEXT_MAX_LENGTH) - Toast.notify({ type: 'error', message: t(`${i18nPrefix}.userActions.buttonTextTooLong`, { ns: 'workflow', maxLength: BUTTON_TEXT_MAX_LENGTH }) }) + toast.error(t(`${i18nPrefix}.userActions.buttonTextTooLong`, { ns: 'workflow', maxLength: BUTTON_TEXT_MAX_LENGTH })) } onChange({ ...data, title: value }) } diff --git a/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..ce9bdfc295 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-config.spec.ts @@ -0,0 +1,156 @@ +import type { DeliveryMethod, HumanInputNodeType, UserAction } from '../../types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseUpdateNodeInternals = vi.hoisted(() => vi.fn()) +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseEdgesInteractions = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseFormContent = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', () => ({ + useUpdateNodeInternals: () => mockUseUpdateNodeInternals(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), +})) + +vi.mock('@/app/components/workflow/hooks/use-edges-interactions', () => ({ + useEdgesInteractions: () => mockUseEdgesInteractions(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('../use-form-content', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseFormContent(...args), +})) + +const createPayload = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [{ + id: 'webapp', + type: 'webapp', + enabled: true, + } as DeliveryMethod], + form_content: 'Body', + inputs: [], + user_actions: [{ + id: 'approve', + title: 'Approve', + button_style: 'primary', + } as UserAction], + timeout: 3, + timeout_unit: 'day', + ...overrides, +}) + +describe('human-input/hooks/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleEdgeDeleteByDeleteBranch = vi.fn() + const mockHandleEdgeSourceHandleChange = vi.fn() + const mockUpdateNodeInternals = vi.fn() + const formContentHook = { + editorKey: 3, + handleFormContentChange: vi.fn(), + handleFormInputsChange: vi.fn(), + handleFormInputItemRename: vi.fn(), + handleFormInputItemRemove: vi.fn(), + } + let currentInputs = createPayload() + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + mockUseUpdateNodeInternals.mockReturnValue(mockUpdateNodeInternals) + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false }) + mockUseEdgesInteractions.mockReturnValue({ + handleEdgeDeleteByDeleteBranch: mockHandleEdgeDeleteByDeleteBranch, + handleEdgeSourceHandleChange: mockHandleEdgeSourceHandleChange, + }) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseFormContent.mockReturnValue(formContentHook) + }) + + it('should expose form-content helpers and update delivery methods, timeout, and collapsed state', () => { + const { result } = renderHook(() => useConfig('human-input-node', currentInputs)) + const methods = [{ + id: 'email', + type: 'email', + enabled: true, + } as DeliveryMethod] + + expect(result.current.editorKey).toBe(3) + expect(result.current.readOnly).toBe(false) + expect(result.current.structuredOutputCollapsed).toBe(true) + + act(() => { + result.current.handleDeliveryMethodChange(methods) + result.current.handleTimeoutChange({ timeout: 12, unit: 'hour' }) + result.current.setStructuredOutputCollapsed(false) + }) + + expect(mockSetInputs).toHaveBeenNthCalledWith(1, expect.objectContaining({ + delivery_methods: methods, + })) + expect(mockSetInputs).toHaveBeenNthCalledWith(2, expect.objectContaining({ + timeout: 12, + timeout_unit: 'hour', + })) + expect(result.current.structuredOutputCollapsed).toBe(false) + }) + + it('should append and delete user actions while syncing branch-edge cleanup', () => { + const { result } = renderHook(() => useConfig('human-input-node', currentInputs)) + const newAction = { + id: 'reject', + title: 'Reject', + button_style: 'default', + } as UserAction + + act(() => { + result.current.handleUserActionAdd(newAction) + result.current.handleUserActionDelete('approve') + }) + + expect(mockSetInputs).toHaveBeenNthCalledWith(1, expect.objectContaining({ + user_actions: [ + expect.objectContaining({ id: 'approve' }), + newAction, + ], + })) + expect(mockSetInputs).toHaveBeenNthCalledWith(2, expect.objectContaining({ + user_actions: [], + })) + expect(mockHandleEdgeDeleteByDeleteBranch).toHaveBeenCalledWith('human-input-node', 'approve') + }) + + it('should update user action ids and refresh source handles when the branch key changes', () => { + const { result } = renderHook(() => useConfig('human-input-node', currentInputs)) + const renamedAction = { + id: 'approved', + title: 'Approve', + button_style: 'primary', + } as UserAction + + act(() => { + result.current.handleUserActionChange(0, renamedAction) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + user_actions: [renamedAction], + })) + expect(mockHandleEdgeSourceHandleChange).toHaveBeenCalledWith('human-input-node', 'approve', 'approved') + expect(mockUpdateNodeInternals).toHaveBeenCalledWith('human-input-node') + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-form-content.spec.ts b/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-form-content.spec.ts new file mode 100644 index 0000000000..c809e51595 --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-form-content.spec.ts @@ -0,0 +1,112 @@ +import type { FormInputItem, HumanInputNodeType } from '../../types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import useFormContent from '../use-form-content' + +const mockUseWorkflow = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflow: () => mockUseWorkflow(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +const createFormInput = (overrides: Partial = {}): FormInputItem => ({ + type: InputVarType.textInput, + output_variable_name: 'old_name', + default: { + selector: [], + type: 'constant', + value: '', + }, + ...overrides, +}) + +const createPayload = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [], + form_content: 'Hello {{#$output.old_name#}}', + inputs: [createFormInput()], + user_actions: [], + timeout: 1, + timeout_unit: 'day', + ...overrides, +}) + +describe('human-input/use-form-content', () => { + const mockSetInputs = vi.fn() + const mockHandleOutVarRenameChange = vi.fn() + let currentInputs = createPayload() + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + mockUseWorkflow.mockReturnValue({ + handleOutVarRenameChange: mockHandleOutVarRenameChange, + }) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + }) + + it('should update raw form content and replace the form input list', () => { + const { result } = renderHook(() => useFormContent('human-input-node', currentInputs)) + const nextInputs = [ + createFormInput({ + output_variable_name: 'approval', + }), + ] + + act(() => { + result.current.handleFormContentChange('Updated body') + result.current.handleFormInputsChange(nextInputs) + }) + + expect(mockSetInputs).toHaveBeenNthCalledWith(1, expect.objectContaining({ + form_content: 'Updated body', + })) + expect(mockSetInputs).toHaveBeenNthCalledWith(2, expect.objectContaining({ + inputs: nextInputs, + })) + expect(result.current.editorKey).toBe(1) + }) + + it('should rename input placeholders inside markdown and notify downstream references', () => { + const { result } = renderHook(() => useFormContent('human-input-node', currentInputs)) + const renamedInput = createFormInput({ + output_variable_name: 'new_name', + }) + + act(() => { + result.current.handleFormInputItemRename(renamedInput, 'old_name') + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + form_content: 'Hello {{#$output.new_name#}}', + inputs: [renamedInput], + })) + expect(mockHandleOutVarRenameChange).toHaveBeenCalledWith('human-input-node', ['human-input-node', 'old_name'], ['human-input-node', 'new_name']) + expect(result.current.editorKey).toBe(1) + }) + + it('should remove an input placeholder and its form input metadata', () => { + const { result } = renderHook(() => useFormContent('human-input-node', currentInputs)) + + act(() => { + result.current.handleFormInputItemRemove('old_name') + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + form_content: 'Hello ', + inputs: [], + })) + expect(result.current.editorKey).toBe(1) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..571708e87d --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/hooks/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,234 @@ +import type { HumanInputNodeType } from '../../types' +import type { InputVar } from '@/app/components/workflow/types' +import type { HumanInputFormData } from '@/types/workflow' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import { AppModeEnum } from '@/types/app' +import useSingleRunFormParams from '../use-single-run-form-params' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseAppStore = vi.hoisted(() => vi.fn()) +const mockFetchHumanInputNodeStepRunForm = vi.hoisted(() => vi.fn()) +const mockSubmitHumanInputNodeStepRunForm = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: { appDetail?: { id?: string, mode?: AppModeEnum } }) => unknown) => mockUseAppStore(selector), +})) + +vi.mock('@/service/workflow', () => ({ + fetchHumanInputNodeStepRunForm: (...args: unknown[]) => mockFetchHumanInputNodeStepRunForm(...args), + submitHumanInputNodeStepRunForm: (...args: unknown[]) => mockSubmitHumanInputNodeStepRunForm(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +const createPayload = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [], + form_content: 'Summary: {{#start.topic#}}', + inputs: [{ + type: InputVarType.textInput, + output_variable_name: 'summary', + default: { + type: 'variable', + selector: ['start', 'topic'], + value: '', + }, + }], + user_actions: [], + timeout: 1, + timeout_unit: 'day', + ...overrides, +}) + +const createInputVar = (overrides: Partial = {}): InputVar => ({ + type: InputVarType.textInput, + label: 'Topic', + variable: '#start.topic#', + required: false, + value_selector: ['start', 'topic'], + ...overrides, +}) + +const mockFormData: HumanInputFormData = { + form_id: 'form-1', + node_id: 'node-1', + node_title: 'Human Input', + form_content: 'Rendered content', + inputs: [], + actions: [], + form_token: 'token-1', + resolved_default_values: { + topic: 'AI', + }, + display_in_ui: true, + expiration_time: 1000, +} + +describe('human-input/hooks/use-single-run-form-params', () => { + const mockSetRunInputData = vi.fn() + const getInputVars = vi.fn() + let currentInputs = createPayload() + let appDetail: { id?: string, mode?: AppModeEnum } | undefined + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + appDetail = { + id: 'app-1', + mode: AppModeEnum.WORKFLOW, + } + + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseAppStore.mockImplementation((selector: (state: { appDetail?: { id?: string, mode?: AppModeEnum } }) => unknown) => selector({ appDetail })) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + })) + getInputVars.mockReturnValue([ + createInputVar(), + createInputVar({ + label: 'Output', + variable: '#$output.answer#', + value_selector: ['$output', 'answer'], + }), + { + ...createInputVar({ + label: 'Broken', + }), + variable: undefined, + } as unknown as InputVar, + ]) + mockFetchHumanInputNodeStepRunForm.mockResolvedValue(mockFormData) + mockSubmitHumanInputNodeStepRunForm.mockResolvedValue({}) + }) + + it('should build a single before-run form, filter output vars, and expose dependent vars', () => { + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'node-1', + payload: currentInputs, + runInputData: { topic: 'AI' }, + getInputVars, + setRunInputData: mockSetRunInputData, + })) + + expect(getInputVars).toHaveBeenCalledWith([ + '{{#start.topic#}}', + 'Summary: {{#start.topic#}}', + ]) + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0]).toEqual(expect.objectContaining({ + label: 'nodes.humanInput.singleRun.label', + values: { topic: 'AI' }, + inputs: [ + expect.objectContaining({ variable: '#start.topic#' }), + expect.objectContaining({ label: 'Broken' }), + ], + })) + + act(() => { + result.current.forms[0].onChange?.({ topic: 'Updated' }) + }) + + expect(mockSetRunInputData).toHaveBeenCalledWith({ topic: 'Updated' }) + expect(result.current.getDependentVars()).toEqual([ + ['start', 'topic'], + ]) + }) + + it('should fetch and submit generated forms in workflow mode while keeping required inputs', async () => { + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'node-1', + payload: currentInputs, + runInputData: {}, + getInputVars, + setRunInputData: mockSetRunInputData, + })) + + await act(async () => { + await result.current.handleShowGeneratedForm({ + topic: 'AI', + ignored: undefined as unknown as string, + }) + }) + + expect(result.current.showGeneratedForm).toBe(true) + expect(mockFetchHumanInputNodeStepRunForm).toHaveBeenCalledWith( + '/apps/app-1/workflows/draft/human-input/nodes/node-1/form', + { + inputs: { topic: 'AI' }, + }, + ) + expect(result.current.formData).toEqual(mockFormData) + + await act(async () => { + await result.current.handleSubmitHumanInputForm({ + inputs: { answer: 'approved' }, + form_inputs: { ignored: 'value' }, + action: 'approve', + }) + }) + + expect(mockSubmitHumanInputNodeStepRunForm).toHaveBeenCalledWith( + '/apps/app-1/workflows/draft/human-input/nodes/node-1/form', + { + inputs: { topic: 'AI' }, + form_inputs: { answer: 'approved' }, + action: 'approve', + }, + ) + + act(() => { + result.current.handleHideGeneratedForm() + }) + + expect(result.current.showGeneratedForm).toBe(false) + }) + + it('should use the advanced-chat endpoint and skip remote fetches when app detail is missing', async () => { + appDetail = { + id: 'app-2', + mode: AppModeEnum.ADVANCED_CHAT, + } + + const { result, rerender } = renderHook(() => useSingleRunFormParams({ + id: 'node-9', + payload: currentInputs, + runInputData: {}, + getInputVars, + setRunInputData: mockSetRunInputData, + })) + + await act(async () => { + await result.current.handleFetchFormContent({ topic: 'hello' }) + }) + + expect(mockFetchHumanInputNodeStepRunForm).toHaveBeenCalledWith( + '/apps/app-2/advanced-chat/workflows/draft/human-input/nodes/node-9/form', + { + inputs: { topic: 'hello' }, + }, + ) + + appDetail = undefined + rerender() + + await act(async () => { + const data = await result.current.handleFetchFormContent({ topic: 'skip' }) + expect(data).toBeNull() + }) + + expect(mockFetchHumanInputNodeStepRunForm).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/panel.tsx b/web/app/components/workflow/nodes/human-input/panel.tsx index 525821d042..c209c6451e 100644 --- a/web/app/components/workflow/nodes/human-input/panel.tsx +++ b/web/app/components/workflow/nodes/human-input/panel.tsx @@ -16,8 +16,8 @@ import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import Split from '@/app/components/workflow/nodes/_base/components/split' import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list' @@ -132,7 +132,7 @@ const Panel: FC> = ({ className="flex size-6 cursor-pointer items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover" onClick={() => { copy(inputs.form_content) - Toast.notify({ type: 'success', message: t('actionMsg.copySuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} > diff --git a/web/app/components/workflow/nodes/iteration/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/iteration/__tests__/integration.spec.tsx index 67de8b188b..dc7538144e 100644 --- a/web/app/components/workflow/nodes/iteration/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/iteration/__tests__/integration.spec.tsx @@ -3,7 +3,7 @@ import type { IterationNodeType } from '../types' import type { PanelProps } from '@/types/workflow' import { fireEvent, render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ErrorHandleMode } from '@/app/components/workflow/types' import { BlockEnum, VarType } from '../../../types' import AddBlock from '../add-block' @@ -15,6 +15,15 @@ const mockHandleNodeAdd = vi.fn() const mockHandleNodeIterationRerender = vi.fn() let mockNodesReadOnly = false +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, +})) + vi.mock('reactflow', async () => { const actual = await vi.importActual('reactflow') return { @@ -102,7 +111,7 @@ vi.mock('../use-config', () => ({ })) const mockUseConfig = vi.mocked(useConfig) -const mockToastNotify = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) +const mockToastWarning = vi.mocked(toast.warning) const createData = (overrides: Partial = {}): IterationNodeType => ({ title: 'Iteration', @@ -191,11 +200,7 @@ describe('iteration path', () => { expect(screen.getByRole('button', { name: 'select-block' })).toBeInTheDocument() expect(screen.getByTestId('iteration-background-iteration-node')).toBeInTheDocument() expect(mockHandleNodeIterationRerender).toHaveBeenCalledWith('iteration-node') - expect(mockToastNotify).toHaveBeenCalledWith({ - type: 'warning', - message: 'workflow.nodes.iteration.answerNodeWarningDesc', - duration: 5000, - }) + expect(mockToastWarning).toHaveBeenCalledWith('workflow.nodes.iteration.answerNodeWarningDesc') }) it('should wire panel input, output, parallel, numeric, error mode, and flatten actions', async () => { diff --git a/web/app/components/workflow/nodes/iteration/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/iteration/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..5bef3eb8a6 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration/__tests__/use-config.spec.ts @@ -0,0 +1,173 @@ +import type { IterationNodeType } from '../types' +import type { Item } from '@/app/components/base/select' +import type { Var } from '@/app/components/workflow/types' +import { act, renderHook } from '@testing-library/react' +import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' +import { BlockEnum, ErrorHandleMode, VarType } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseInspectVarsCrud = vi.hoisted(() => vi.fn()) +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseIsChatMode = vi.hoisted(() => vi.fn()) +const mockUseWorkflow = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseAllBuiltInTools = vi.hoisted(() => vi.fn()) +const mockUseAllCustomTools = vi.hoisted(() => vi.fn()) +const mockUseAllWorkflowTools = vi.hoisted(() => vi.fn()) +const mockUseAllMCPTools = vi.hoisted(() => vi.fn()) +const mockToNodeOutputVars = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks/use-inspect-vars-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseInspectVarsCrud(...args), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), + useIsChatMode: () => mockUseIsChatMode(), + useWorkflow: () => mockUseWorkflow(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { dataSourceList: unknown[] }) => unknown) => + selector({ dataSourceList: mockUseStore() }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => mockUseAllBuiltInTools(), + useAllCustomTools: () => mockUseAllCustomTools(), + useAllWorkflowTools: () => mockUseAllWorkflowTools(), + useAllMCPTools: () => mockUseAllMCPTools(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/utils', () => ({ + toNodeOutputVars: (...args: unknown[]) => mockToNodeOutputVars(...args), +})) + +const createPayload = (overrides: Partial = {}): IterationNodeType => ({ + title: 'Iteration', + desc: '', + type: BlockEnum.Iteration, + iterator_selector: ['start', 'items'], + iterator_input_type: VarType.arrayString, + output_selector: ['child', 'result'], + output_type: VarType.arrayString, + is_parallel: false, + parallel_nums: 3, + error_handle_mode: ErrorHandleMode.Terminated, + flatten_output: false, + start_node_id: 'start-node', + _children: [], + _isShowTips: false, + ...overrides, +}) + +const createVar = (type: VarType, variable = 'test.variable'): Var => ({ + variable, + type, +}) + +describe('iteration/use-config', () => { + const mockSetInputs = vi.fn() + const mockDeleteNodeInspectorVars = vi.fn() + let currentInputs = createPayload() + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + + mockUseInspectVarsCrud.mockReturnValue({ + deleteNodeInspectorVars: mockDeleteNodeInspectorVars, + }) + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false }) + mockUseIsChatMode.mockReturnValue(false) + mockUseWorkflow.mockReturnValue({ + getIterationNodeChildren: vi.fn(() => [{ id: 'child-node' }]), + }) + mockUseStore.mockReturnValue([]) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseAllBuiltInTools.mockReturnValue({ data: [] }) + mockUseAllCustomTools.mockReturnValue({ data: [] }) + mockUseAllWorkflowTools.mockReturnValue({ data: [] }) + mockUseAllMCPTools.mockReturnValue({ data: [] }) + mockToNodeOutputVars.mockReturnValue([{ variable: 'child.result' }]) + }) + + it('should expose iteration children vars and filter only array-like iterator inputs', () => { + const { result } = renderHook(() => useConfig('iteration-node', currentInputs)) + + expect(result.current.readOnly).toBe(false) + expect(result.current.childrenNodeVars).toEqual([{ variable: 'child.result' }]) + expect(result.current.iterationChildrenNodes).toEqual([{ id: 'child-node' }]) + expect(result.current.filterInputVar(createVar(VarType.arrayFile, 'files'))).toBe(true) + expect(result.current.filterInputVar(createVar(VarType.string, 'text'))).toBe(false) + expect(mockToNodeOutputVars).toHaveBeenCalled() + }) + + it('should update iterator input and output selectors and reset inspector vars on output changes', () => { + const { result } = renderHook(() => useConfig('iteration-node', currentInputs)) + + act(() => { + result.current.handleInputChange(['start', 'documents'], VarKindType.variable, createVar(VarType.arrayObject, 'start.documents')) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + iterator_selector: ['start', 'documents'], + iterator_input_type: VarType.arrayObject, + })) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleOutputVarChange(['child', 'score'], VarKindType.variable, createVar(VarType.number, 'child.score')) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + output_selector: ['child', 'score'], + output_type: VarType.arrayNumber, + })) + expect(mockDeleteNodeInspectorVars).toHaveBeenCalledWith('iteration-node') + + mockSetInputs.mockClear() + + act(() => { + result.current.handleOutputVarChange(['child', 'result'], VarKindType.variable, createVar(VarType.string, 'child.result')) + }) + + expect(mockSetInputs).not.toHaveBeenCalled() + }) + + it('should update parallel, error-mode, and flatten options', () => { + const { result } = renderHook(() => useConfig('iteration-node', currentInputs)) + const item: Item = { name: 'Continue', value: ErrorHandleMode.ContinueOnError } + + act(() => { + result.current.changeParallel(true) + result.current.changeErrorResponseMode(item) + result.current.changeParallelNums(6) + result.current.changeFlattenOutput(true) + }) + + expect(mockSetInputs).toHaveBeenNthCalledWith(1, expect.objectContaining({ + is_parallel: true, + })) + expect(mockSetInputs).toHaveBeenNthCalledWith(2, expect.objectContaining({ + error_handle_mode: ErrorHandleMode.ContinueOnError, + })) + expect(mockSetInputs).toHaveBeenNthCalledWith(3, expect.objectContaining({ + parallel_nums: 6, + })) + expect(mockSetInputs).toHaveBeenNthCalledWith(4, expect.objectContaining({ + flatten_output: true, + })) + }) +}) diff --git a/web/app/components/workflow/nodes/iteration/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/iteration/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..7313b6945e --- /dev/null +++ b/web/app/components/workflow/nodes/iteration/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,168 @@ +import type { InputVar, Node } from '../../../types' +import type { IterationNodeType } from '../types' +import type { NodeTracing } from '@/types/workflow' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum, ErrorHandleMode, InputVarType, VarType } from '@/app/components/workflow/types' +import useSingleRunFormParams from '../use-single-run-form-params' + +const mockUseIsNodeInIteration = vi.hoisted(() => vi.fn()) +const mockUseWorkflow = vi.hoisted(() => vi.fn()) +const mockFormatTracing = vi.hoisted(() => vi.fn()) +const mockGetNodeUsedVars = vi.hoisted(() => vi.fn()) +const mockGetNodeUsedVarPassToServerKey = vi.hoisted(() => vi.fn()) +const mockGetNodeInfoById = vi.hoisted(() => vi.fn()) +const mockIsSystemVar = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useIsNodeInIteration: (...args: unknown[]) => mockUseIsNodeInIteration(...args), + useWorkflow: () => mockUseWorkflow(), +})) + +vi.mock('@/app/components/workflow/run/utils/format-log', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockFormatTracing(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/utils', () => ({ + getNodeUsedVars: (...args: unknown[]) => mockGetNodeUsedVars(...args), + getNodeUsedVarPassToServerKey: (...args: unknown[]) => mockGetNodeUsedVarPassToServerKey(...args), + getNodeInfoById: (...args: unknown[]) => mockGetNodeInfoById(...args), + isSystemVar: (...args: unknown[]) => mockIsSystemVar(...args), +})) + +const createInputVar = (variable: string): InputVar => ({ + type: InputVarType.textInput, + label: variable, + variable, + required: false, +}) + +const createNode = (id: string, title: string, type = BlockEnum.Tool): Node => ({ + id, + position: { x: 0, y: 0 }, + data: { + title, + type, + desc: '', + }, +} as Node) + +const createPayload = (overrides: Partial = {}): IterationNodeType => ({ + title: 'Iteration', + desc: '', + type: BlockEnum.Iteration, + start_node_id: 'start-node', + iterator_selector: ['start-node', 'items'], + iterator_input_type: VarType.arrayString, + output_selector: ['child-node', 'text'], + output_type: VarType.arrayString, + is_parallel: false, + parallel_nums: 2, + error_handle_mode: ErrorHandleMode.Terminated, + flatten_output: false, + _children: [], + _isShowTips: false, + ...overrides, +}) + +describe('iteration/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseIsNodeInIteration.mockReturnValue({ + isNodeInIteration: (nodeId: string) => nodeId === 'inner-node', + }) + mockUseWorkflow.mockReturnValue({ + getIterationNodeChildren: () => [ + createNode('tool-a', 'Tool A'), + createNode('inner-node', 'Inner Node'), + ], + getBeforeNodesInSameBranch: () => [ + createNode('start-node', 'Start Node', BlockEnum.Start), + ], + }) + mockGetNodeUsedVars.mockImplementation((node: Node) => { + if (node.id === 'tool-a') + return [['start-node', 'answer'], ['inner-node', 'secret'], ['iteration-node', 'item']] + return [] + }) + mockGetNodeUsedVarPassToServerKey.mockReturnValue('passed_key') + mockGetNodeInfoById.mockImplementation((nodes: Node[], id: string) => nodes.find(node => node.id === id)) + mockIsSystemVar.mockReturnValue(false) + mockFormatTracing.mockReturnValue([{ id: 'formatted-node' }]) + }) + + it('should build single-run forms from external vars and keep iterator state in a dedicated form', () => { + const toVarInputs = vi.fn(() => [createInputVar('#start-node.answer#')]) + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'iteration-node', + payload: createPayload(), + runInputData: { + 'query': 'hello', + 'iteration-node.input_selector': ['start-node', 'items'], + }, + runInputDataRef: { current: {} }, + getInputVars: vi.fn(), + setRunInputData: vi.fn(), + toVarInputs, + iterationRunResult: [], + })) + + expect(toVarInputs).toHaveBeenCalledWith([ + expect.objectContaining({ + variable: 'start-node.answer', + value_selector: ['start-node', 'answer'], + }), + ]) + expect(result.current.forms).toHaveLength(2) + expect(result.current.forms[0].inputs).toEqual([createInputVar('#start-node.answer#')]) + expect(result.current.forms[0].values).toEqual({ + 'query': 'hello', + 'iteration-node.input_selector': ['start-node', 'items'], + }) + expect(result.current.forms[1].values).toEqual({ + 'iteration-node.input_selector': ['start-node', 'items'], + }) + expect(result.current.allVarObject).toEqual({ + 'start-node.answer@@@tool-a@@@0': { + inSingleRunPassedKey: 'passed_key', + }, + }) + expect(result.current.nodeInfo).toEqual({ id: 'formatted-node' }) + }) + + it('should forward form updates and expose iterator dependencies', () => { + const setRunInputData = vi.fn() + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'iteration-node', + payload: createPayload({ + iterator_selector: ['source-node', 'records'], + }), + runInputData: { + 'query': 'old', + 'iteration-node.input_selector': ['source-node', 'records'], + }, + runInputDataRef: { current: {} }, + getInputVars: vi.fn(), + setRunInputData, + toVarInputs: vi.fn(() => []), + iterationRunResult: [] as NodeTracing[], + })) + + act(() => { + result.current.forms[0].onChange({ query: 'new' }) + result.current.forms[1].onChange({ + 'iteration-node.input_selector': ['source-node', 'next'], + }) + }) + + expect(setRunInputData).toHaveBeenNthCalledWith(1, { query: 'new' }) + expect(setRunInputData).toHaveBeenNthCalledWith(2, { + 'query': 'old', + 'iteration-node.input_selector': ['source-node', 'next'], + }) + expect(result.current.getDependentVars()).toEqual([['source-node', 'records']]) + expect(result.current.getDependentVar('iteration-node.input_selector')).toEqual(['source-node', 'records']) + }) +}) diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index 476266211a..667c68144f 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -12,7 +12,7 @@ import { useNodesInitialized, useViewport, } from 'reactflow' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { cn } from '@/utils/classnames' import { IterationStartNodeDumb } from '../iteration-start' import AddBlock from './add-block' @@ -34,11 +34,7 @@ const Node: FC> = ({ if (nodesInitialized) handleNodeIterationRerender(id) if (data.is_parallel && showTips) { - Toast.notify({ - type: 'warning', - message: t(`${i18nPrefix}.answerNodeWarningDesc`, { ns: 'workflow' }), - duration: 5000, - }) + toast.warning(t(`${i18nPrefix}.answerNodeWarningDesc`, { ns: 'workflow' })) setShowTips(false) } }, [nodesInitialized, id, handleNodeIterationRerender, data.is_parallel, showTips, t]) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx index a19dccad78..164e7c3c29 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx @@ -4,7 +4,7 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { JSON_SCHEMA_MAX_DEPTH } from '@/config' import { cn } from '@/utils/classnames' import { SegmentedControl } from '../../../../../base/segmented-control' @@ -196,10 +196,7 @@ const JsonSchemaConfig: FC = ({ } else if (currentTab === SchemaView.VisualEditor) { if (advancedEditing || isAddingNewField) { - Toast.notify({ - type: 'warning', - message: t('nodes.llm.jsonSchema.warningTips.saveSchema', { ns: 'workflow' }), - }) + toast.warning(t('nodes.llm.jsonSchema.warningTips.saveSchema', { ns: 'workflow' })) return } } diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx index 6a34925275..9e2fba5fa2 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx @@ -9,7 +9,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import useTheme from '@/hooks/use-theme' @@ -112,10 +112,7 @@ const JsonSchemaGenerator: FC = ({ const generateSchema = useCallback(async () => { const { output, error } = await generateStructuredOutputRules({ instruction, model_config: model! }) if (error) { - Toast.notify({ - type: 'error', - message: error, - }) + toast.error(error) setSchema(null) setView(GeneratorView.promptEditor) return diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts index 6159028c21..4820b5a9dc 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts @@ -3,7 +3,8 @@ import type { Field } from '../../../types' import type { EditData } from './edit-card' import { noop } from 'es-toolkit/function' import { produce } from 'immer' -import Toast from '@/app/components/base/toast' +import { useTranslation } from 'react-i18next' +import { toast } from '@/app/components/base/ui/toast' import { ArrayType, Type } from '../../../types' import { findPropertyWithPath } from '../../../utils' import { useMittContext } from './context' @@ -22,6 +23,7 @@ type AddEventParams = { export const useSchemaNodeOperations = (props: VisualEditorProps) => { const { schema: jsonSchema, onChange: doOnChange } = props + const { t } = useTranslation() const onChange = doOnChange || noop const backupSchema = useVisualEditorStore(state => state.backupSchema) const setBackupSchema = useVisualEditorStore(state => state.setBackupSchema) @@ -65,10 +67,7 @@ export const useSchemaNodeOperations = (props: VisualEditorProps) => { if (schema.type === Type.object) { const properties = schema.properties || {} if (properties[newName]) { - Toast.notify({ - type: 'error', - message: 'Property name already exists', - }) + toast.error(t('nodes.llm.jsonSchema.fieldNameAlreadyExists', { ns: 'workflow' })) emit('restorePropertyName') return } @@ -92,10 +91,7 @@ export const useSchemaNodeOperations = (props: VisualEditorProps) => { if (schema.type === Type.array && schema.items && schema.items.type === Type.object) { const properties = schema.items.properties || {} if (properties[newName]) { - Toast.notify({ - type: 'error', - message: 'Property name already exists', - }) + toast.error(t('nodes.llm.jsonSchema.fieldNameAlreadyExists', { ns: 'workflow' })) emit('restorePropertyName') return } @@ -267,10 +263,7 @@ export const useSchemaNodeOperations = (props: VisualEditorProps) => { if (oldName !== newName) { const properties = parentSchema.properties if (properties[newName]) { - Toast.notify({ - type: 'error', - message: 'Property name already exists', - }) + toast.error(t('nodes.llm.jsonSchema.fieldNameAlreadyExists', { ns: 'workflow' })) samePropertyNameError = true } @@ -358,10 +351,7 @@ export const useSchemaNodeOperations = (props: VisualEditorProps) => { if (oldName !== newName) { const properties = parentSchema.items.properties || {} if (properties[newName]) { - Toast.notify({ - type: 'error', - message: 'Property name already exists', - }) + toast.error(t('nodes.llm.jsonSchema.fieldNameAlreadyExists', { ns: 'workflow' })) samePropertyNameError = true } diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 7a7640948d..c3e6d0fee2 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -7,8 +7,8 @@ import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import AddButton2 from '@/app/components/base/button/add-button' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import Field from '@/app/components/workflow/nodes/_base/components/field' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' @@ -98,11 +98,11 @@ const Panel: FC> = ({ ) const keys = Object.keys(removedDetails) if (keys.length) - Toast.notify({ type: 'warning', message: `${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}` }) + toast.warning(`${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`) handleCompletionParamsChange(filtered) } catch { - Toast.notify({ type: 'error', message: t('error', { ns: 'common' }) }) + toast.error(t('error', { ns: 'common' })) handleCompletionParamsChange({}) } finally { diff --git a/web/app/components/workflow/nodes/loop/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/loop/__tests__/integration.spec.tsx index 10b8dad885..99ce377b99 100644 --- a/web/app/components/workflow/nodes/loop/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/loop/__tests__/integration.spec.tsx @@ -181,10 +181,12 @@ vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', ( ), })) -vi.mock('@/app/components/base/toast', () => ({ - __esModule: true, - default: { - notify: (payload: unknown) => mockToastNotify(payload), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), }, })) diff --git a/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx b/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx index 9ceb92f432..7bd7e09c1b 100644 --- a/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx +++ b/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx @@ -7,7 +7,7 @@ import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ValueType, VarType } from '@/app/components/workflow/types' import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var' import FormItem from './form-item' @@ -28,10 +28,7 @@ const Item = ({ const checkVariableName = (value: string) => { const { isValid, errorMessageKey } = checkKeys([value], false) if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: t('env.modal.name', { ns: 'workflow' }) }), - }) + toast.error(t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: t('env.modal.name', { ns: 'workflow' }) })) return false } return true diff --git a/web/app/components/workflow/nodes/parameter-extractor/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/parameter-extractor/__tests__/integration.spec.tsx index 3eeb59e620..60b9d65260 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/__tests__/integration.spec.tsx @@ -6,7 +6,7 @@ import type { ToolDefaultValue } from '@/app/components/workflow/block-selector/ import type { PanelProps } from '@/types/workflow' import { fireEvent, render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useTextGenerationCurrentProviderAndModelAndModelList, } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -36,6 +36,15 @@ let mockWorkflowTools: MockToolCollection[] = [] let mockSelectedToolInfo: ToolDefaultValue | undefined let mockBlockSelectorOpen = false +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, +})) + vi.mock('@/app/components/workflow/block-selector', () => ({ __esModule: true, default: ({ @@ -254,7 +263,7 @@ vi.mock('../use-config', () => ({ const mockUseTextGeneration = vi.mocked(useTextGenerationCurrentProviderAndModelAndModelList) const mockUseConfig = vi.mocked(useConfig) -const mockToastNotify = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) +const mockToastError = vi.mocked(toast.error) const createToolParameter = (overrides: Partial = {}): ToolParameter => ({ name: 'city', @@ -356,7 +365,7 @@ const panelProps: PanelProps = { describe('parameter-extractor path', () => { beforeEach(() => { vi.clearAllMocks() - mockToastNotify.mockClear() + mockToastError.mockClear() mockBuiltInTools = [] mockCustomTools = [] mockWorkflowTools = [] @@ -582,7 +591,7 @@ describe('parameter-extractor path', () => { await user.click(screen.getByRole('button', { name: 'common.operation.save' })) expect(onSave).not.toHaveBeenCalled() - expect(mockToastNotify).toHaveBeenCalled() + expect(mockToastError).toHaveBeenCalled() }) it('should render the add trigger for new parameters', () => { @@ -614,7 +623,7 @@ describe('parameter-extractor path', () => { const descriptionInput = screen.getByPlaceholderText('workflow.nodes.parameterExtractor.addExtractParameterContent.descriptionPlaceholder') fireEvent.change(nameInput, { target: { value: '1bad' } }) - expect(mockToastNotify).toHaveBeenCalled() + expect(mockToastError).toHaveBeenCalled() expect(nameInput).toHaveValue('') fireEvent.change(nameInput, { target: { value: 'temporary_name' } }) @@ -649,7 +658,7 @@ describe('parameter-extractor path', () => { await user.click(screen.getByRole('button', { name: 'common.operation.save' })) expect(onSave).not.toHaveBeenCalled() - expect(mockToastNotify).toHaveBeenCalled() + expect(mockToastError).toHaveBeenCalled() }) it('should keep rename metadata and updated options when editing a select parameter', async () => { diff --git a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx index 5a4113848a..e1b9c1574f 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx @@ -15,7 +15,7 @@ import Modal from '@/app/components/base/modal' import Select from '@/app/components/base/select' import Switch from '@/app/components/base/switch' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ChangeType } from '@/app/components/workflow/types' import { checkKeys } from '@/utils/var' import { ParamType } from '../../types' @@ -54,10 +54,7 @@ const AddExtractParameter: FC = ({ if (key === 'name') { const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: errorKey }), - }) + toast.error(t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: errorKey })) return } } @@ -106,10 +103,7 @@ const AddExtractParameter: FC = ({ errMessage = t(`${errorI18nPrefix}.fieldRequired`, { ns: 'workflow', field: t(`${i18nPrefix}.addExtractParameterContent.description`, { ns: 'workflow' }) }) if (errMessage) { - Toast.notify({ - type: 'error', - message: errMessage, - }) + toast.error(errMessage) return false } return true diff --git a/web/app/components/workflow/nodes/start/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/start/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..330b1ac776 --- /dev/null +++ b/web/app/components/workflow/nodes/start/__tests__/use-config.spec.ts @@ -0,0 +1,245 @@ +import type { StartNodeType } from '../types' +import type { InputVar, ValueSelector } from '@/app/components/workflow/types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum, ChangeType, InputVarType } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseWorkflow = vi.hoisted(() => vi.fn()) +const mockUseIsChatMode = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseInspectVarsCrud = vi.hoisted(() => vi.fn()) +const mockNotify = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), + useWorkflow: () => mockUseWorkflow(), + useIsChatMode: () => mockUseIsChatMode(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/hooks/use-inspect-vars-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseInspectVarsCrud(...args), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + __esModule: true, + toast: { + error: (message: string) => mockNotify({ type: 'error', message }), + }, +})) + +const createInputVar = (overrides: Partial = {}): InputVar => ({ + label: 'Question', + variable: 'query', + type: InputVarType.textInput, + required: true, + ...overrides, +}) + +const createPayload = (overrides: Partial = {}): StartNodeType => ({ + title: 'Start', + desc: '', + type: BlockEnum.Start, + variables: [ + createInputVar(), + createInputVar({ + label: 'Age', + variable: 'age', + type: InputVarType.number, + required: false, + }), + ], + ...overrides, +}) + +describe('start/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleOutVarRenameChange = vi.fn() + const mockIsVarUsedInNodes = vi.fn() + const mockRemoveUsedVarInNodes = vi.fn() + const mockDeleteNodeInspectorVars = vi.fn() + const mockRenameInspectVarName = vi.fn() + const mockDeleteInspectVar = vi.fn() + const toastSpy = mockNotify + let currentInputs: StartNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false }) + mockUseWorkflow.mockReturnValue({ + handleOutVarRenameChange: mockHandleOutVarRenameChange, + isVarUsedInNodes: mockIsVarUsedInNodes, + removeUsedVarInNodes: mockRemoveUsedVarInNodes, + }) + mockUseIsChatMode.mockReturnValue(false) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseInspectVarsCrud.mockReturnValue({ + deleteNodeInspectorVars: mockDeleteNodeInspectorVars, + renameInspectVarName: mockRenameInspectVarName, + nodesWithInspectVars: [{ + nodeId: 'start-node', + vars: [{ id: 'inspect-query', name: 'query' }], + }], + deleteInspectVar: mockDeleteInspectVar, + }) + mockIsVarUsedInNodes.mockReturnValue(false) + }) + + it('should rename variables and sync downstream variable references', () => { + const { result } = renderHook(() => useConfig('start-node', currentInputs)) + const renamedList = [ + createInputVar({ + label: 'Question', + variable: 'prompt', + }), + createInputVar({ + label: 'Age', + variable: 'age', + type: InputVarType.number, + required: false, + }), + ] + + act(() => { + result.current.handleVarListChange(renamedList, { + index: 0, + payload: { + type: ChangeType.changeVarName, + payload: { + beforeKey: 'query', + }, + }, + }) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + variables: renamedList, + })) + expect(mockHandleOutVarRenameChange).toHaveBeenCalledWith('start-node', ['start-node', 'query'], ['start-node', 'prompt']) + expect(mockRenameInspectVarName).toHaveBeenCalledWith('start-node', 'query', 'prompt') + expect(result.current.readOnly).toBe(false) + expect(result.current.isChatMode).toBe(false) + }) + + it('should block removal when the variable is still in use and confirm the deletion later', () => { + mockIsVarUsedInNodes.mockReturnValue(true) + const { result } = renderHook(() => useConfig('start-node', currentInputs)) + const nextList = [currentInputs.variables[1]] + + act(() => { + result.current.handleVarListChange(nextList, { + index: 0, + payload: { + type: ChangeType.remove, + payload: { + beforeKey: 'query', + }, + }, + }) + }) + + expect(mockDeleteInspectVar).toHaveBeenCalledWith('start-node', 'inspect-query') + expect(mockSetInputs).not.toHaveBeenCalled() + expect(result.current.isShowRemoveVarConfirm).toBe(true) + + act(() => { + result.current.onRemoveVarConfirm() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + variables: [expect.objectContaining({ variable: 'age' })], + })) + expect(mockRemoveUsedVarInNodes).toHaveBeenCalledWith(['start-node', 'query'] as ValueSelector) + expect(result.current.isShowRemoveVarConfirm).toBe(false) + }) + + it('should validate duplicate variables and labels before adding a new variable', () => { + const { result } = renderHook(() => useConfig('start-node', currentInputs)) + + let added = true + act(() => { + added = result.current.handleAddVariable(createInputVar({ + label: 'Different Label', + variable: 'query', + })) + }) + + expect(added).toBe(false) + expect(toastSpy).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'varKeyError.keyAlreadyExists', + })) + + mockSetInputs.mockClear() + let addedUnique = false + act(() => { + addedUnique = result.current.handleAddVariable(createInputVar({ + label: 'Locale', + variable: 'locale', + required: false, + })) + }) + + expect(addedUnique).toBe(true) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + variables: expect.arrayContaining([ + expect.objectContaining({ variable: 'locale' }), + ]), + })) + }) + + it('should clear inspector vars for non-remove list updates and reject duplicate labels', () => { + const { result } = renderHook(() => useConfig('start-node', currentInputs)) + const typeEditedList = [ + createInputVar({ + label: 'Question', + variable: 'query', + type: InputVarType.paragraph, + }), + currentInputs.variables[1], + ] + + act(() => { + result.current.handleVarListChange(typeEditedList) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + variables: typeEditedList, + })) + expect(mockDeleteNodeInspectorVars).toHaveBeenCalledWith('start-node') + + toastSpy.mockClear() + let added = true + act(() => { + added = result.current.handleAddVariable(createInputVar({ + label: 'Age', + variable: 'new_age', + })) + }) + + expect(added).toBe(false) + expect(toastSpy).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'varKeyError.keyAlreadyExists', + })) + }) +}) diff --git a/web/app/components/workflow/nodes/start/components/var-list.tsx b/web/app/components/workflow/nodes/start/components/var-list.tsx index 21fe5e2bb5..a6158864ac 100644 --- a/web/app/components/workflow/nodes/start/components/var-list.tsx +++ b/web/app/components/workflow/nodes/start/components/var-list.tsx @@ -7,7 +7,7 @@ import * as React from 'react' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { ReactSortable } from 'react-sortablejs' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ChangeType } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' import { hasDuplicateStr } from '@/utils/var' @@ -43,10 +43,7 @@ const VarList: FC = ({ } if (errorMsgKey && typeName) { - Toast.notify({ - type: 'error', - message: t(errorMsgKey, { ns: 'appDebug', key: t(typeName, { ns: 'appDebug' }) }), - }) + toast.error(t(errorMsgKey, { ns: 'appDebug', key: t(typeName, { ns: 'appDebug' }) })) return false } onChange(newList, moreInfo ? { index, payload: moreInfo } : undefined) diff --git a/web/app/components/workflow/nodes/start/use-config.ts b/web/app/components/workflow/nodes/start/use-config.ts index 232c788b6d..12ec1575c9 100644 --- a/web/app/components/workflow/nodes/start/use-config.ts +++ b/web/app/components/workflow/nodes/start/use-config.ts @@ -4,7 +4,7 @@ import { useBoolean } from 'ahooks' import { produce } from 'immer' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useIsChatMode, useNodesReadOnly, @@ -97,10 +97,7 @@ const useConfig = (id: string, payload: StartNodeType) => { } if (errorMsgKey && typeName) { - Toast.notify({ - type: 'error', - message: t(errorMsgKey, { ns: 'appDebug', key: t(typeName, { ns: 'appDebug' }) }), - }) + toast.error(t(errorMsgKey, { ns: 'appDebug', key: t(typeName, { ns: 'appDebug' }) })) return false } setInputs(newInputs) diff --git a/web/app/components/workflow/nodes/tool/hooks/use-config.ts b/web/app/components/workflow/nodes/tool/hooks/use-config.ts index 6ebe7bea26..5e3f928dcb 100644 --- a/web/app/components/workflow/nodes/tool/hooks/use-config.ts +++ b/web/app/components/workflow/nodes/tool/hooks/use-config.ts @@ -5,7 +5,7 @@ import { capitalize } from 'es-toolkit/string' import { produce } from 'immer' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { CollectionType } from '@/app/components/tools/types' import { @@ -66,10 +66,7 @@ const useConfig = (id: string, payload: ToolNodeType) => { async (value: any) => { await updateBuiltInToolCredential(currCollection?.name as string, value) - Toast.notify({ - type: 'success', - message: t('api.actionSuccess', { ns: 'common' }), - }) + toast.success(t('api.actionSuccess', { ns: 'common' })) invalidToolsByType() hideSetAuthModal() }, diff --git a/web/app/components/workflow/nodes/trigger-webhook/__tests__/use-config.spec.tsx b/web/app/components/workflow/nodes/trigger-webhook/__tests__/use-config.spec.tsx index 46d0490b65..92a0457598 100644 --- a/web/app/components/workflow/nodes/trigger-webhook/__tests__/use-config.spec.tsx +++ b/web/app/components/workflow/nodes/trigger-webhook/__tests__/use-config.spec.tsx @@ -1,7 +1,7 @@ import type { WebhookTriggerNodeType } from '../types' import { renderHook } from '@testing-library/react' import { useStore as useAppStore } from '@/app/components/app/store' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { BlockEnum, VarType } from '@/app/components/workflow/types' import { fetchWebhookUrl } from '@/service/apps' import { createNodeCrudModuleMock } from '../../__tests__/use-config-test-utils' @@ -18,10 +18,10 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/ui/toast', () => ({ __esModule: true, - default: { - notify: vi.fn(), + toast: { + error: vi.fn(), }, })) @@ -42,7 +42,7 @@ vi.mock('@/service/apps', () => ({ })) const mockedFetchWebhookUrl = vi.mocked(fetchWebhookUrl) -const mockedToastNotify = vi.mocked(Toast.notify) +const mockedToastError = vi.mocked(toast.error) const createPayload = (overrides: Partial = {}): WebhookTriggerNodeType => ({ title: 'Webhook', @@ -148,7 +148,7 @@ describe('useConfig', () => { }), ]), })) - expect(mockedToastNotify).toHaveBeenCalledTimes(1) + expect(mockedToastError).toHaveBeenCalledTimes(1) }) it('should generate webhook urls once and fall back to empty url on request failure', async () => { diff --git a/web/app/components/workflow/nodes/trigger-webhook/panel.tsx b/web/app/components/workflow/nodes/trigger-webhook/panel.tsx index 839ca6875f..f600fa516d 100644 --- a/web/app/components/workflow/nodes/trigger-webhook/panel.tsx +++ b/web/app/components/workflow/nodes/trigger-webhook/panel.tsx @@ -8,7 +8,6 @@ import { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import InputWithCopy from '@/app/components/base/input-with-copy' import { SimpleSelect } from '@/app/components/base/select' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' import { NumberField, @@ -18,6 +17,7 @@ import { NumberFieldIncrement, NumberFieldInput, } from '@/app/components/base/ui/number-field' +import { toast } from '@/app/components/base/ui/toast' import Field from '@/app/components/workflow/nodes/_base/components/field' import OutputVars from '@/app/components/workflow/nodes/_base/components/output-vars' import Split from '@/app/components/workflow/nodes/_base/components/split' @@ -102,10 +102,7 @@ const Panel: FC> = ({ placeholder={t(`${i18nPrefix}.webhookUrlPlaceholder`, { ns: 'workflow' })} readOnly onCopy={() => { - Toast.notify({ - type: 'success', - message: t(`${i18nPrefix}.urlCopied`, { ns: 'workflow' }), - }) + toast.success(t(`${i18nPrefix}.urlCopied`, { ns: 'workflow' })) }} />
diff --git a/web/app/components/workflow/nodes/trigger-webhook/use-config.ts b/web/app/components/workflow/nodes/trigger-webhook/use-config.ts index 15ebff7736..7924a35ba0 100644 --- a/web/app/components/workflow/nodes/trigger-webhook/use-config.ts +++ b/web/app/components/workflow/nodes/trigger-webhook/use-config.ts @@ -2,7 +2,7 @@ import type { HttpMethod, WebhookHeader, WebhookParameter, WebhookTriggerNodeTyp import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useNodesReadOnly, useWorkflow } from '@/app/components/workflow/hooks' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import { fetchWebhookUrl } from '@/service/apps' @@ -33,10 +33,7 @@ export const useConfig = (id: string, payload: WebhookTriggerNodeType) => { ? t(key as never, { ns: 'appDebug', key: fieldLabel }) : t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: fieldLabel }) - Toast.notify({ - type: 'error', - message, - }) + toast.error(message) }, [t]) const handleMethodChange = useCallback((method: HttpMethod) => { diff --git a/web/app/components/workflow/nodes/variable-assigner/__tests__/hooks.spec.ts b/web/app/components/workflow/nodes/variable-assigner/__tests__/hooks.spec.ts new file mode 100644 index 0000000000..0cbb98c96a --- /dev/null +++ b/web/app/components/workflow/nodes/variable-assigner/__tests__/hooks.spec.ts @@ -0,0 +1,244 @@ +import { act, renderHook } from '@testing-library/react' +import { VarType } from '../../../types' +import { useGetAvailableVars, useVariableAssigner } from '../hooks' + +const mockUseStoreApi = vi.hoisted(() => vi.fn()) +const mockUseNodes = vi.hoisted(() => vi.fn()) +const mockUseNodeDataUpdate = vi.hoisted(() => vi.fn()) +const mockUseWorkflow = vi.hoisted(() => vi.fn()) +const mockUseWorkflowVariables = vi.hoisted(() => vi.fn()) +const mockUseIsChatMode = vi.hoisted(() => vi.fn()) +const mockUseWorkflowStore = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', () => ({ + useStoreApi: () => mockUseStoreApi(), + useNodes: () => mockUseNodes(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodeDataUpdate: () => mockUseNodeDataUpdate(), + useWorkflow: () => mockUseWorkflow(), + useWorkflowVariables: () => mockUseWorkflowVariables(), + useIsChatMode: () => mockUseIsChatMode(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => mockUseWorkflowStore(), +})) + +describe('variable-assigner/hooks', () => { + const mockHandleNodeDataUpdate = vi.fn() + const mockSetNodes = vi.fn() + const mockSetShowAssignVariablePopup = vi.fn() + const mockSetHoveringAssignVariableGroupId = vi.fn() + const getNodes = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + getNodes.mockReturnValue([{ + id: 'assigner-1', + data: { + variables: [['start', 'foo']], + output_type: VarType.string, + advanced_settings: { + groups: [{ + groupId: 'group-1', + variables: [], + output_type: VarType.string, + }], + }, + }, + }]) + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes, + setNodes: mockSetNodes, + }), + }) + mockUseNodeDataUpdate.mockReturnValue({ + handleNodeDataUpdate: mockHandleNodeDataUpdate, + }) + mockUseWorkflowStore.mockReturnValue({ + getState: () => ({ + setShowAssignVariablePopup: mockSetShowAssignVariablePopup, + setHoveringAssignVariableGroupId: mockSetHoveringAssignVariableGroupId, + connectingNodePayload: { id: 'connecting-node' }, + }), + }) + mockUseNodes.mockReturnValue([]) + mockUseWorkflow.mockReturnValue({ + getBeforeNodesInSameBranchIncludeParent: vi.fn(), + }) + mockUseWorkflowVariables.mockReturnValue({ + getNodeAvailableVars: vi.fn(), + }) + mockUseIsChatMode.mockReturnValue(false) + }) + + it('should append target variables, ignore duplicates, and update grouped variables', () => { + const { result } = renderHook(() => useVariableAssigner()) + + act(() => { + result.current.handleAssignVariableValueChange('assigner-1', ['start', 'bar'], { type: VarType.number } as never) + result.current.handleAssignVariableValueChange('assigner-1', ['start', 'foo'], { type: VarType.number } as never) + result.current.handleAssignVariableValueChange('assigner-1', ['start', 'grouped'], { type: VarType.arrayString } as never, 'group-1') + }) + + expect(mockHandleNodeDataUpdate).toHaveBeenNthCalledWith(1, { + id: 'assigner-1', + data: { + variables: [ + ['start', 'foo'], + ['start', 'bar'], + ], + output_type: VarType.number, + }, + }) + expect(mockHandleNodeDataUpdate).toHaveBeenNthCalledWith(2, { + id: 'assigner-1', + data: { + advanced_settings: { + groups: [{ + groupId: 'group-1', + variables: [['start', 'grouped']], + output_type: VarType.arrayString, + }], + }, + }, + }) + expect(mockHandleNodeDataUpdate).toHaveBeenCalledTimes(2) + }) + + it('should close the popup and add variables through the positioned add-variable flow', () => { + getNodes.mockReturnValue([ + { + id: 'source-node', + data: { + _showAddVariablePopup: true, + _holdAddVariablePopup: true, + }, + }, + { + id: 'assigner-1', + data: { + variables: [], + advanced_settings: { + groups: [{ + groupId: 'group-1', + variables: [], + }], + }, + _showAddVariablePopup: true, + _holdAddVariablePopup: true, + }, + }, + ]) + + const { result } = renderHook(() => useVariableAssigner()) + + act(() => { + result.current.handleAddVariableInAddVariablePopupWithPosition( + 'source-node', + 'assigner-1', + 'group-1', + ['start', 'output'], + { type: VarType.object } as never, + ) + }) + + expect(mockSetNodes).toHaveBeenCalledWith([ + expect.objectContaining({ + id: 'source-node', + data: expect.objectContaining({ + _showAddVariablePopup: false, + _holdAddVariablePopup: false, + }), + }), + expect.objectContaining({ + id: 'assigner-1', + data: expect.objectContaining({ + _showAddVariablePopup: false, + _holdAddVariablePopup: false, + }), + }), + ]) + expect(mockSetShowAssignVariablePopup).toHaveBeenCalledWith(undefined) + expect(mockHandleNodeDataUpdate).toHaveBeenCalledWith({ + id: 'assigner-1', + data: { + advanced_settings: { + groups: [{ + groupId: 'group-1', + variables: [['start', 'output']], + output_type: VarType.object, + }], + }, + }, + }) + }) + + it('should update the hovered group state on enter and leave', () => { + const { result } = renderHook(() => useVariableAssigner()) + + act(() => { + result.current.handleGroupItemMouseEnter('group-1') + result.current.handleGroupItemMouseLeave() + }) + + expect(mockSetHoveringAssignVariableGroupId).toHaveBeenNthCalledWith(1, 'group-1') + expect(mockSetHoveringAssignVariableGroupId).toHaveBeenNthCalledWith(2, undefined) + }) + + it('should collect available vars and filter start-node env vars when hideEnv is enabled', () => { + mockUseNodes.mockReturnValue([ + { + id: 'current-node', + parentId: 'parent-node', + }, + { + id: 'before-1', + }, + { + id: 'parent-node', + }, + ]) + const getBeforeNodesInSameBranchIncludeParent = vi.fn(() => [ + { id: 'before-1' }, + { id: 'before-1' }, + ]) + const getNodeAvailableVars = vi.fn() + .mockReturnValueOnce([{ + isStartNode: true, + vars: [ + { variable: 'sys.user_id' }, + { variable: 'foo' }, + ], + }, { + isStartNode: false, + vars: [], + }]) + .mockReturnValueOnce([{ + isStartNode: false, + vars: [{ variable: 'bar' }], + }]) + + mockUseWorkflow.mockReturnValue({ + getBeforeNodesInSameBranchIncludeParent, + }) + mockUseWorkflowVariables.mockReturnValue({ + getNodeAvailableVars, + }) + + const { result } = renderHook(() => useGetAvailableVars()) + + expect(result.current('current-node', 'target', () => true, true)).toEqual([{ + isStartNode: true, + vars: [{ variable: 'foo' }], + }]) + expect(result.current('current-node', 'target', () => true, false)).toEqual([{ + isStartNode: false, + vars: [{ variable: 'bar' }], + }]) + expect(result.current('missing-node', 'target', () => true)).toEqual([]) + }) +}) diff --git a/web/app/components/workflow/nodes/variable-assigner/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/variable-assigner/__tests__/integration.spec.tsx index 2769e867a5..e3e9661cb9 100644 --- a/web/app/components/workflow/nodes/variable-assigner/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/variable-assigner/__tests__/integration.spec.tsx @@ -3,7 +3,7 @@ import type { VariableAssignerNodeType } from '../types' import type { PanelProps } from '@/types/workflow' import { fireEvent, render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { renderWorkflowFlowComponent } from '@/app/components/workflow/__tests__/workflow-test-env' import { BlockEnum, VarType } from '@/app/components/workflow/types' import AddVariable from '../components/add-variable' @@ -19,6 +19,15 @@ const mockHandleGroupItemMouseEnter = vi.fn() const mockHandleGroupItemMouseLeave = vi.fn() const mockGetAvailableVars = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, +})) + vi.mock('@/app/components/workflow/nodes/_base/components/add-variable-popup', () => ({ default: ({ onSelect }: any) => ( + ) + }, +})) + +vi.mock('@/app/components/workflow/run/loop-log', () => ({ + LoopLogTrigger: (props: { + onShowLoopResultList: (detail: unknown, durationMap: unknown) => void + nodeInfo: { details?: unknown, loopDurationMap?: unknown } + }) => { + mockLoopLogTrigger(props) + return ( + + ) + }, +})) + +vi.mock('@/app/components/workflow/run/retry-log', () => ({ + RetryLogTrigger: (props: { + onShowRetryResultList: (detail: unknown) => void + nodeInfo: { retryDetail?: unknown } + }) => { + mockRetryLogTrigger(props) + return ( + + ) + }, +})) + +vi.mock('@/app/components/workflow/run/agent-log', () => ({ + AgentLogTrigger: (props: { + onShowAgentOrToolLog: (detail: unknown) => void + nodeInfo: { agentLog?: unknown } + }) => { + mockAgentLogTrigger(props) + return ( + + ) + }, +})) + +vi.mock('@/app/components/workflow/variable-inspect/large-data-alert', () => ({ + __esModule: true, + default: (props: { downloadUrl?: string }) => { + mockLargeDataAlert(props) + return
{props.downloadUrl ?? 'no-download'}
+ }, +})) + +vi.mock('@/app/components/workflow/run/meta', () => ({ + __esModule: true, + default: (props: Record) => { + mockMetaData(props) + return
{JSON.stringify(props)}
+ }, +})) + +vi.mock('@/app/components/workflow/run/status', () => ({ + __esModule: true, + default: (props: Record) => { + mockStatusPanel(props) + return
{JSON.stringify(props)}
+ }, +})) + +const createNodeInfo = (overrides: Partial = {}): NodeTracing => ({ + id: 'trace-node-1', + index: 0, + predecessor_node_id: '', + node_id: 'node-1', + node_type: BlockEnum.Code, + title: 'Code', + inputs: {}, + inputs_truncated: false, + process_data: {}, + process_data_truncated: false, + outputs_truncated: false, + status: NodeRunningStatus.Succeeded, + elapsed_time: 0, + metadata: { + iterator_length: 0, + iterator_index: 0, + loop_length: 0, + loop_index: 0, + }, + created_at: 0, + created_by: { + id: 'user-1', + name: 'User', + email: 'user@example.com', + }, + finished_at: 1, + details: undefined, + retryDetail: undefined, + agentLog: undefined, + iterDurationMap: undefined, + loopDurationMap: undefined, + ...overrides, +}) + +const createLogDetail = (id: string): NodeTracing => createNodeInfo({ + id: `trace-${id}`, + node_id: id, + title: id, +}) + +const createAgentLog = (label: string): AgentLogItemWithChildren => ({ + node_execution_id: `execution-${label}`, + message_id: `message-${label}`, + node_id: `node-${label}`, + parent_id: undefined, + label, + status: 'success', + data: {}, + metadata: {}, + children: [], +}) + +describe('ResultPanel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + }) + + it('should render status, editors, alerts, error strategy tip, and metadata', () => { + render( + , + ) + + expect(screen.getByTestId('status-panel')).toBeInTheDocument() + expect(screen.getByText('COMMON.INPUT')).toBeInTheDocument() + expect(screen.getByText('COMMON.PROCESSDATA')).toBeInTheDocument() + expect(screen.getByText('COMMON.OUTPUT')).toBeInTheDocument() + expect(screen.getAllByTestId('code-editor')).toHaveLength(3) + expect(screen.getAllByTestId('large-data-alert')).toHaveLength(3) + expect(screen.getByTestId('error-handle-tip')).toHaveTextContent('continue-on-error') + expect(screen.getByTestId('meta-data')).toBeInTheDocument() + expect(mockStatusPanel).toHaveBeenCalledWith(expect.objectContaining({ + status: NodeRunningStatus.Succeeded, + time: 2.5, + tokens: 42, + error: 'boom', + exceptionCounts: 1, + isListening: true, + workflowRunId: 'run-1', + })) + expect(mockMetaData).toHaveBeenCalledWith(expect.objectContaining({ + status: NodeRunningStatus.Succeeded, + executor: 'Alice', + startTime: 1710000000, + time: 2.5, + tokens: 42, + steps: 3, + showSteps: true, + })) + expect(mockLargeDataAlert).toHaveBeenLastCalledWith(expect.objectContaining({ + downloadUrl: 'https://example.com/output.json', + })) + }) + + it('should render and invoke iteration and loop triggers only when their handlers are provided', () => { + const handleShowIterationResultList = vi.fn() + const handleShowLoopResultList = vi.fn() + const details = [[createLogDetail('iter-1')]] + + const { rerender } = render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'iteration-trigger' })) + expect(handleShowIterationResultList).toHaveBeenCalledWith(details, { 0: 3 }) + + rerender( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'loop-trigger' })) + expect(handleShowLoopResultList).toHaveBeenCalledWith(details, { 0: 5 }) + }) + + it('should render retry and agent/tool triggers when the node shape supports them', () => { + const onShowRetryDetail = vi.fn() + const handleShowAgentOrToolLog = vi.fn() + const retryDetail = [createLogDetail('retry-1')] + const agentLog = [createAgentLog('tool-call')] + + const { rerender } = render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'retry-trigger' })) + expect(onShowRetryDetail).toHaveBeenCalledWith(retryDetail) + + rerender( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'agent-trigger' })) + expect(handleShowAgentOrToolLog).toHaveBeenCalledWith(agentLog) + + rerender( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'agent-trigger' })) + expect(handleShowAgentOrToolLog).toHaveBeenLastCalledWith(agentLog) + }) + + it('should still render the output editor while the node is running even without outputs', () => { + render( + , + ) + + expect(screen.getByText('COMMON.OUTPUT')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/run/__tests__/tracing-panel.spec.tsx b/web/app/components/workflow/run/__tests__/tracing-panel.spec.tsx new file mode 100644 index 0000000000..f5445f5f9f --- /dev/null +++ b/web/app/components/workflow/run/__tests__/tracing-panel.spec.tsx @@ -0,0 +1,199 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { getHoveredParallelId } from '../get-hovered-parallel-id' +import TracingPanel from '../tracing-panel' + +const mockUseTranslation = vi.hoisted(() => vi.fn()) +const mockFormatNodeList = vi.hoisted(() => vi.fn()) +const mockUseLogs = vi.hoisted(() => vi.fn()) +const mockNodePanel = vi.hoisted(() => vi.fn()) +const mockSpecialResultPanel = vi.hoisted(() => vi.fn()) + +vi.mock('react-i18next', () => ({ + useTranslation: () => mockUseTranslation(), +})) + +vi.mock('@/app/components/workflow/run/utils/format-log', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockFormatNodeList(...args), +})) + +vi.mock('../hooks', () => ({ + useLogs: () => mockUseLogs(), +})) + +vi.mock('../node', () => ({ + __esModule: true, + default: (props: { + nodeInfo: { id: string } + }) => { + mockNodePanel(props) + return
{props.nodeInfo.id}
+ }, +})) + +vi.mock('../special-result-panel', () => ({ + __esModule: true, + default: (props: Record) => { + mockSpecialResultPanel(props) + return
special
+ }, +})) + +describe('TracingPanel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseTranslation.mockReturnValue({ + t: (key: string) => key, + }) + mockUseLogs.mockReturnValue({ + showSpecialResultPanel: false, + showRetryDetail: false, + setShowRetryDetailFalse: vi.fn(), + retryResultList: [], + handleShowRetryResultList: vi.fn(), + showIteratingDetail: false, + setShowIteratingDetailFalse: vi.fn(), + iterationResultList: [], + iterationResultDurationMap: {}, + handleShowIterationResultList: vi.fn(), + showLoopingDetail: false, + setShowLoopingDetailFalse: vi.fn(), + loopResultList: [], + loopResultDurationMap: {}, + loopResultVariableMap: {}, + handleShowLoopResultList: vi.fn(), + agentOrToolLogItemStack: [], + agentOrToolLogListMap: {}, + handleShowAgentOrToolLog: vi.fn(), + }) + }) + + it('should render formatted nodes, preserve branch labels, and collapse parallel groups', () => { + mockFormatNodeList.mockReturnValue([ + { + id: 'parallel-1', + parallelDetail: { + isParallelStartNode: true, + parallelTitle: 'Parallel Group', + children: [{ + id: 'child-1', + title: 'Child Node', + parallelDetail: { + branchTitle: 'Branch A', + }, + }], + }, + }, + { + id: 'node-2', + title: 'Standalone Node', + parallelDetail: { + branchTitle: 'Branch B', + }, + }, + ]) + + const parentClick = vi.fn() + const { container } = render( +
+ +
, + ) + + expect(screen.getByText('Parallel Group')).toBeInTheDocument() + expect(screen.getByText('Branch A')).toBeInTheDocument() + expect(screen.getByText('Branch B')).toBeInTheDocument() + expect(screen.getByTestId('node-child-1')).toBeInTheDocument() + expect(screen.getByTestId('node-node-2')).toBeInTheDocument() + + fireEvent.click(container.querySelector('.py-2') as HTMLElement) + expect(parentClick).not.toHaveBeenCalled() + + const hoverTarget = screen.getByText('Parallel Group').closest('[data-parallel-id="parallel-1"]') as HTMLElement + const nestedParallelTarget = document.createElement('div') + nestedParallelTarget.setAttribute('data-parallel-id', 'parallel-1') + const unrelatedTarget = document.createElement('div') + document.body.appendChild(nestedParallelTarget) + document.body.appendChild(unrelatedTarget) + + fireEvent.mouseEnter(hoverTarget) + const sameParallelOut = new MouseEvent('mouseout', { bubbles: true }) + Object.defineProperty(sameParallelOut, 'relatedTarget', { value: nestedParallelTarget }) + hoverTarget.dispatchEvent(sameParallelOut) + + const differentTargetOut = new MouseEvent('mouseout', { bubbles: true }) + Object.defineProperty(differentTargetOut, 'relatedTarget', { value: unrelatedTarget }) + hoverTarget.dispatchEvent(differentTargetOut) + + fireEvent.mouseLeave(hoverTarget) + + fireEvent.click(screen.getAllByRole('button')[0]) + expect(container.querySelector('[data-parallel-id="parallel-1"] > div:last-child')).toHaveClass('hidden') + fireEvent.click(screen.getAllByRole('button')[0]) + expect(container.querySelector('[data-parallel-id="parallel-1"] > div:last-child')).not.toHaveClass('hidden') + expect(mockNodePanel).toHaveBeenCalledWith(expect.objectContaining({ + hideInfo: true, + hideProcessDetail: true, + })) + + nestedParallelTarget.remove() + unrelatedTarget.remove() + }) + + it('should switch to the special result panel when the log state requests it', () => { + mockUseLogs.mockReturnValue({ + showSpecialResultPanel: true, + showRetryDetail: true, + setShowRetryDetailFalse: vi.fn(), + retryResultList: [{ id: 'retry-1' }], + handleShowRetryResultList: vi.fn(), + showIteratingDetail: true, + setShowIteratingDetailFalse: vi.fn(), + iterationResultList: [[{ id: 'iter-1' }]], + iterationResultDurationMap: { 0: 1 }, + handleShowIterationResultList: vi.fn(), + showLoopingDetail: true, + setShowLoopingDetailFalse: vi.fn(), + loopResultList: [[{ id: 'loop-1' }]], + loopResultDurationMap: { 0: 2 }, + loopResultVariableMap: { 0: {} }, + handleShowLoopResultList: vi.fn(), + agentOrToolLogItemStack: [{ id: 'agent-1' }], + agentOrToolLogListMap: { agent: [] }, + handleShowAgentOrToolLog: vi.fn(), + }) + + render() + + expect(screen.getByTestId('special-result-panel')).toBeInTheDocument() + expect(mockSpecialResultPanel).toHaveBeenCalledWith(expect.objectContaining({ + showRetryDetail: true, + retryResultList: [{ id: 'retry-1' }], + showIteratingDetail: true, + showLoopingDetail: true, + agentOrToolLogItemStack: [{ id: 'agent-1' }], + })) + }) + + it('should resolve hovered parallel ids from related targets', () => { + const sameParallelTarget = document.createElement('div') + sameParallelTarget.setAttribute('data-parallel-id', 'parallel-1') + document.body.appendChild(sameParallelTarget) + + const nestedChild = document.createElement('span') + sameParallelTarget.appendChild(nestedChild) + + const unrelatedTarget = document.createElement('div') + + expect(getHoveredParallelId(nestedChild)).toBe('parallel-1') + expect(getHoveredParallelId(unrelatedTarget)).toBeNull() + expect(getHoveredParallelId(null)).toBeNull() + + sameParallelTarget.remove() + }) +}) diff --git a/web/app/components/workflow/run/get-hovered-parallel-id.ts b/web/app/components/workflow/run/get-hovered-parallel-id.ts new file mode 100644 index 0000000000..cd369d5eb1 --- /dev/null +++ b/web/app/components/workflow/run/get-hovered-parallel-id.ts @@ -0,0 +1,10 @@ +export const getHoveredParallelId = (relatedTarget: EventTarget | null) => { + const element = relatedTarget as Element | null + if (element && 'closest' in element) { + const closestParallel = element.closest('[data-parallel-id]') + if (closestParallel) + return closestParallel.getAttribute('data-parallel-id') + } + + return null +} diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index b96037c765..0b0467bb09 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -4,9 +4,8 @@ import type { WorkflowRunDetailResponse } from '@/models/log' import type { NodeTracing } from '@/types/workflow' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { WorkflowRunningStatus } from '@/app/components/workflow/types' import { fetchRunDetail, fetchTracingList } from '@/service/log' import { cn } from '@/utils/classnames' @@ -32,7 +31,6 @@ const RunPanel: FC = ({ tracingListUrl, }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [currentTab, setCurrentTab] = useState(activeTab) const [loading, setLoading] = useState(true) const [runDetail, setRunDetail] = useState() @@ -55,12 +53,9 @@ const RunPanel: FC = ({ getResultCallback(res) } catch (err) { - notify({ - type: 'error', - message: `${err}`, - }) + toast.error(`${err}`) } - }, [notify, getResultCallback, runDetailUrl]) + }, [getResultCallback, runDetailUrl]) const getTracingList = useCallback(async () => { try { @@ -70,12 +65,9 @@ const RunPanel: FC = ({ setList(nodeList) } catch (err) { - notify({ - type: 'error', - message: `${err}`, - }) + toast.error(`${err}`) } - }, [notify, tracingListUrl]) + }, [tracingListUrl]) const getData = useCallback(async () => { setLoading(true) diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index 8931c8f7fe..dba158f0b2 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -1,10 +1,6 @@ 'use client' import type { FC } from 'react' import type { NodeTracing } from '@/types/workflow' -import { - RiArrowDownSLine, - RiMenu4Line, -} from '@remixicon/react' import * as React from 'react' import { useCallback, @@ -13,6 +9,7 @@ import { import { useTranslation } from 'react-i18next' import formatNodeList from '@/app/components/workflow/run/utils/format-log' import { cn } from '@/utils/classnames' +import { getHoveredParallelId } from './get-hovered-parallel-id' import { useLogs } from './hooks' import NodePanel from './node' import SpecialResultPanel from './special-result-panel' @@ -53,18 +50,7 @@ const TracingPanel: FC = ({ }, []) const handleParallelMouseLeave = useCallback((e: React.MouseEvent) => { - const relatedTarget = e.relatedTarget as Element | null - if (relatedTarget && 'closest' in relatedTarget) { - const closestParallel = relatedTarget.closest('[data-parallel-id]') - if (closestParallel) - setHoveredParallel(closestParallel.getAttribute('data-parallel-id')) - - else - setHoveredParallel(null) - } - else { - setHoveredParallel(null) - } + setHoveredParallel(getHoveredParallelId(e.relatedTarget)) }, []) const { @@ -116,9 +102,11 @@ const TracingPanel: FC = ({ isHovered ? 'rounded border-components-button-primary-border bg-components-button-primary-bg text-text-primary-on-surface' : 'text-text-secondary hover:text-text-primary', )} > - {isHovered ? : } + {isHovered + ? + : } -
+
{parallelDetail.parallelTitle}
= ({ const isHovered = hoveredParallel === node.id return (
-
+
{node?.parallelDetail?.branchTitle}
vi.fn()) +const mockFormatHumanInputNode = vi.hoisted(() => vi.fn()) +const mockFormatRetryNode = vi.hoisted(() => vi.fn()) +const mockAddChildrenToLoopNode = vi.hoisted(() => vi.fn()) +const mockAddChildrenToIterationNode = vi.hoisted(() => vi.fn()) +const mockFormatParallelNode = vi.hoisted(() => vi.fn()) + +vi.mock('../agent', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockFormatAgentNode(...args), +})) + +vi.mock('../human-input', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockFormatHumanInputNode(...args), +})) + +vi.mock('../retry', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockFormatRetryNode(...args), +})) + +vi.mock('../loop', () => ({ + addChildrenToLoopNode: (...args: unknown[]) => mockAddChildrenToLoopNode(...args), +})) + +vi.mock('../iteration', () => ({ + addChildrenToIterationNode: (...args: unknown[]) => mockAddChildrenToIterationNode(...args), +})) + +vi.mock('../parallel', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockFormatParallelNode(...args), +})) + +const createTrace = (overrides: Partial = {}): NodeTracing => ({ + id: overrides.id ?? overrides.node_id ?? 'node-1', + index: overrides.index ?? 0, + predecessor_node_id: '', + node_id: overrides.node_id ?? 'node-1', + node_type: overrides.node_type ?? BlockEnum.Tool, + title: overrides.title ?? 'Node', + inputs: {}, + inputs_truncated: false, + process_data: {}, + process_data_truncated: false, + outputs_truncated: false, + status: overrides.status ?? 'succeeded', + error: overrides.error, + elapsed_time: 1, + execution_metadata: overrides.execution_metadata ?? { + total_tokens: 0, + total_price: 0, + currency: 'USD', + }, + metadata: { + iterator_length: 0, + iterator_index: 0, + loop_length: 0, + loop_index: 0, + }, + created_at: 0, + created_by: { + id: 'user-1', + name: 'User', + email: 'user@example.com', + }, + finished_at: 1, +}) + +const createExecutionMetadata = (overrides: Partial> = {}): NonNullable => ({ + total_tokens: 0, + total_price: 0, + currency: 'USD', + ...overrides, +}) + +describe('formatToTracingNodeList', () => { + beforeEach(() => { + vi.clearAllMocks() + mockFormatAgentNode.mockImplementation((list: NodeTracing[]) => list) + mockFormatHumanInputNode.mockImplementation((list: NodeTracing[]) => list) + mockFormatRetryNode.mockImplementation((list: NodeTracing[]) => list) + mockAddChildrenToLoopNode.mockImplementation((item: NodeTracing, children: NodeTracing[]) => ({ + ...item, + loopChildren: children.map(child => child.node_id), + details: [[{ id: 'loop-detail-row' }]], + })) + mockAddChildrenToIterationNode.mockImplementation((item: NodeTracing, children: NodeTracing[]) => ({ + ...item, + iterationChildren: children.map(child => child.node_id), + details: [[{ id: 'iteration-detail-row' }]], + })) + mockFormatParallelNode.mockImplementation((list: unknown[]) => + list.map(item => ({ + ...(item as Record), + parallelFormatted: true, + }))) + }) + + it('should sort the input by index and run the formatter pipeline in order', () => { + const t = vi.fn((key: string) => key) + const traces = [ + createTrace({ id: 'b', node_id: 'b', title: 'B', index: 2 }), + createTrace({ id: 'a', node_id: 'a', title: 'A', index: 0 }), + createTrace({ id: 'c', node_id: 'c', title: 'C', index: 1 }), + ] + + const result = formatToTracingNodeList(traces, t) + + expect(mockFormatAgentNode).toHaveBeenCalledWith([ + expect.objectContaining({ node_id: 'a' }), + expect.objectContaining({ node_id: 'c' }), + expect.objectContaining({ node_id: 'b' }), + ]) + expect(mockFormatHumanInputNode).toHaveBeenCalledWith(mockFormatAgentNode.mock.results[0].value) + expect(mockFormatRetryNode).toHaveBeenCalledWith(mockFormatHumanInputNode.mock.results[0].value) + expect(mockFormatParallelNode).toHaveBeenLastCalledWith(expect.any(Array), t) + expect(result).toEqual([ + expect.objectContaining({ node_id: 'a', parallelFormatted: true }), + expect.objectContaining({ node_id: 'c', parallelFormatted: true }), + expect.objectContaining({ node_id: 'b', parallelFormatted: true }), + ]) + }) + + it('should collapse loop and iteration children into parent nodes and propagate child failures', () => { + const t = vi.fn((key: string) => key) + const loopParent = createTrace({ + id: 'loop-parent', + node_id: 'loop-parent', + node_type: BlockEnum.Loop, + index: 0, + }) + const loopChild = createTrace({ + id: 'loop-child', + node_id: 'loop-child', + index: 1, + status: 'failed', + error: 'loop child failed', + execution_metadata: createExecutionMetadata({ loop_id: 'loop-parent' }), + }) + const iterationParent = createTrace({ + id: 'iteration-parent', + node_id: 'iteration-parent', + node_type: BlockEnum.Iteration, + index: 2, + }) + const iterationChild = createTrace({ + id: 'iteration-child', + node_id: 'iteration-child', + index: 3, + status: 'failed', + error: 'iteration child failed', + execution_metadata: createExecutionMetadata({ iteration_id: 'iteration-parent' }), + }) + + const result = formatToTracingNodeList([ + loopParent, + loopChild, + iterationParent, + iterationChild, + ], t) + + expect(mockAddChildrenToLoopNode).toHaveBeenCalledWith( + expect.objectContaining({ + node_id: 'loop-parent', + status: 'failed', + error: 'loop child failed', + }), + [expect.objectContaining({ node_id: 'loop-child' })], + ) + expect(mockAddChildrenToIterationNode).toHaveBeenCalledWith( + expect.objectContaining({ + node_id: 'iteration-parent', + status: 'failed', + error: 'iteration child failed', + }), + [expect.objectContaining({ node_id: 'iteration-child' })], + ) + expect(mockFormatParallelNode).toHaveBeenCalledTimes(3) + expect(result).toEqual([ + expect.objectContaining({ + node_id: 'loop-parent', + loopChildren: ['loop-child'], + parallelFormatted: true, + }), + expect.objectContaining({ + node_id: 'iteration-parent', + iterationChildren: ['iteration-child'], + parallelFormatted: true, + }), + ]) + }) +}) diff --git a/web/app/components/workflow/update-dsl-modal.tsx b/web/app/components/workflow/update-dsl-modal.tsx index 13af4bba1d..e4b0d7067d 100644 --- a/web/app/components/workflow/update-dsl-modal.tsx +++ b/web/app/components/workflow/update-dsl-modal.tsx @@ -13,12 +13,11 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Uploader from '@/app/components/app/create-from-dsl-modal/uploader' import { useStore as useAppStore } from '@/app/components/app/store' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { useEventEmitterContextContext } from '@/context/event-emitter' import { @@ -54,7 +53,6 @@ const UpdateDSLModal = ({ onImport, }: UpdateDSLModalProps) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const appDetail = useAppStore(s => s.appDetail) const [currentFile, setDSLFile] = useState() const [fileContent, setFileContent] = useState() @@ -110,17 +108,18 @@ const UpdateDSLModal = ({ const isCreatingRef = useRef(false) const handleCompletedImport = useCallback(async (status: DSLImportStatus, appId?: string) => { if (!appId) { - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) + toast.error(t('common.importFailure', { ns: 'workflow' })) return } - handleWorkflowUpdate(appId) + await handleWorkflowUpdate(appId) onImport?.() - notify(getImportNotificationPayload(status, t)) + const payload = getImportNotificationPayload(status, t) + toast[payload.type](payload.message, payload.children ? { description: payload.children } : undefined) await handleCheckPluginDependencies(appId) setLoading(false) onCancel() - }, [handleCheckPluginDependencies, handleWorkflowUpdate, notify, onCancel, onImport, t]) + }, [handleCheckPluginDependencies, handleWorkflowUpdate, onCancel, onImport, t]) const handlePendingImport = useCallback((id: string, importedVersion?: string | null, currentVersion?: string | null) => { setShow(false) @@ -138,8 +137,10 @@ const UpdateDSLModal = ({ if (isCreatingRef.current) return isCreatingRef.current = true - if (!currentFile) + if (!currentFile) { + isCreatingRef.current = false return + } try { if (appDetail && fileContent && validateDSLContent(fileContent, appDetail.mode)) { setLoading(true) @@ -154,20 +155,20 @@ const UpdateDSLModal = ({ } else { setLoading(false) - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) + toast.error(t('common.importFailure', { ns: 'workflow' })) } } else if (fileContent) { - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) + toast.error(t('common.importFailure', { ns: 'workflow' })) } } // eslint-disable-next-line unused-imports/no-unused-vars catch (e) { setLoading(false) - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) + toast.error(t('common.importFailure', { ns: 'workflow' })) } isCreatingRef.current = false - }, [currentFile, fileContent, notify, t, appDetail, handleCompletedImport, handlePendingImport]) + }, [currentFile, fileContent, t, appDetail, handleCompletedImport, handlePendingImport]) const onUpdateDSLConfirm: MouseEventHandler = async () => { try { @@ -179,28 +180,18 @@ const UpdateDSLModal = ({ const { status, app_id } = response - if (status === DSLImportStatus.COMPLETED) { - if (!app_id) { - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) - return - } - handleWorkflowUpdate(app_id) - await handleCheckPluginDependencies(app_id) - if (onImport) - onImport() - notify({ type: 'success', message: t('common.importSuccess', { ns: 'workflow' }) }) - setLoading(false) - onCancel() + if (isImportCompleted(status)) { + await handleCompletedImport(status, app_id) } else if (status === DSLImportStatus.FAILED) { setLoading(false) - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) + toast.error(t('common.importFailure', { ns: 'workflow' })) } } // eslint-disable-next-line unused-imports/no-unused-vars catch (e) { setLoading(false) - notify({ type: 'error', message: t('common.importFailure', { ns: 'workflow' }) }) + toast.error(t('common.importFailure', { ns: 'workflow' })) } } diff --git a/web/context/i18n.spec.ts b/web/context/i18n.spec.ts index 616f3bfced..9ebbda825e 100644 --- a/web/context/i18n.spec.ts +++ b/web/context/i18n.spec.ts @@ -184,8 +184,8 @@ describe('useDocLink', () => { vi.mocked(getDocLanguage).mockReturnValue('ja') const { result } = renderHook(() => useDocLink()) - const url = result.current('/api-reference/application/get-application-basic-information') - expect(url).toBe(`${defaultDocBaseUrl}/api-reference/アプリケーション情報/アプリケーションの基本情報を取得`) + const url = result.current('/api-reference/applications/get-app-info') + expect(url).toBe(`${defaultDocBaseUrl}/api-reference/アプリケーション設定/アプリケーションの基本情報を取得`) }) it('should not translate API reference path for English locale', () => { diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index ffce581afd..4a0bb909e2 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -6381,14 +6381,9 @@ "count": 1 } }, - "app/components/tools/workflow-tool/hooks/use-configure-button.ts": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/tools/workflow-tool/index.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "tailwindcss/enforce-consistent-class-order": { "count": 7 @@ -6591,7 +6586,7 @@ }, "app/components/workflow/block-selector/tool-picker.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/workflow/block-selector/tool/action-item.tsx": { @@ -6686,9 +6681,6 @@ } }, "app/components/workflow/header/header-in-restoring.tsx": { - "no-restricted-imports": { - "count": 1 - }, "tailwindcss/no-unnecessary-whitespace": { "count": 1 } @@ -6702,9 +6694,6 @@ "no-console": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -6762,9 +6751,6 @@ } }, "app/components/workflow/hooks/use-checklist.ts": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-empty-object-type": { "count": 1 }, @@ -6875,9 +6861,6 @@ } }, "app/components/workflow/nodes/_base/components/before-run-form/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 5 } @@ -7306,9 +7289,6 @@ } }, "app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 }, @@ -7328,9 +7308,6 @@ } }, "app/components/workflow/nodes/_base/hooks/use-one-step-run.ts": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 2 }, @@ -7595,7 +7572,7 @@ }, "app/components/workflow/nodes/http/components/curl-panel.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx": { @@ -7652,7 +7629,7 @@ }, "app/components/workflow/nodes/human-input/components/delivery-method/email-configure-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "tailwindcss/enforce-consistent-class-order": { "count": 8 @@ -7764,11 +7741,6 @@ "count": 2 } }, - "app/components/workflow/nodes/human-input/components/user-action.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/workflow/nodes/human-input/components/variable-in-markdown.tsx": { "react-refresh/only-export-components": { "count": 2 @@ -7781,7 +7753,7 @@ }, "app/components/workflow/nodes/human-input/panel.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "tailwindcss/enforce-consistent-class-order": { "count": 4 @@ -7881,9 +7853,6 @@ } }, "app/components/workflow/nodes/iteration/node.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 } @@ -8191,9 +8160,6 @@ "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx": { "erasable-syntax-only/enums": { "count": 1 - }, - "no-restricted-imports": { - "count": 1 } }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx": { @@ -8209,7 +8175,7 @@ "count": 1 }, "no-restricted-imports": { - "count": 2 + "count": 1 }, "react/set-state-in-effect": { "count": 2 @@ -8270,9 +8236,6 @@ } }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -8292,7 +8255,7 @@ }, "app/components/workflow/nodes/llm/panel.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/workflow/nodes/llm/types.ts": { @@ -8402,9 +8365,6 @@ } }, "app/components/workflow/nodes/loop/components/loop-variables/item.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 4 } @@ -8439,7 +8399,7 @@ }, "app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx": { "no-restricted-imports": { - "count": 3 + "count": 2 }, "ts/no-explicit-any": { "count": 1 @@ -8517,11 +8477,6 @@ "count": 8 } }, - "app/components/workflow/nodes/start/components/var-list.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/workflow/nodes/start/node.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -8536,9 +8491,6 @@ } }, "app/components/workflow/nodes/start/use-config.ts": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -8617,9 +8569,6 @@ } }, "app/components/workflow/nodes/tool/hooks/use-config.ts": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 6 } @@ -8723,12 +8672,7 @@ }, "app/components/workflow/nodes/trigger-webhook/panel.tsx": { "no-restricted-imports": { - "count": 3 - } - }, - "app/components/workflow/nodes/trigger-webhook/use-config.ts": { - "no-restricted-imports": { - "count": 1 + "count": 2 } }, "app/components/workflow/nodes/trigger-webhook/utils/render-output-vars.tsx": { @@ -8752,9 +8696,6 @@ } }, "app/components/workflow/nodes/variable-assigner/components/var-group-item.tsx": { - "no-restricted-imports": { - "count": 1 - }, "tailwindcss/enforce-consistent-class-order": { "count": 2 }, @@ -8879,9 +8820,6 @@ } }, "app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 }, @@ -8910,11 +8848,6 @@ "count": 1 } }, - "app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/workflow/panel/chat-variable-panel/components/variable-type-select.tsx": { "no-restricted-imports": { "count": 1 @@ -8953,9 +8886,6 @@ } }, "app/components/workflow/panel/debug-and-preview/hooks.ts": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 12 } @@ -8980,7 +8910,7 @@ }, "app/components/workflow/panel/env-panel/variable-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "react/set-state-in-effect": { "count": 4 @@ -9134,9 +9064,6 @@ } }, "app/components/workflow/run/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 2 } @@ -9222,11 +9149,6 @@ "count": 15 } }, - "app/components/workflow/run/tracing-panel.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 2 - } - }, "app/components/workflow/run/utils/format-log/agent/index.ts": { "ts/no-explicit-any": { "count": 11 @@ -9298,7 +9220,7 @@ }, "app/components/workflow/update-dsl-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "ts/no-explicit-any": { "count": 1 diff --git a/web/hooks/use-api-access-url.ts b/web/hooks/use-api-access-url.ts index 98576e66db..7f63b7754e 100644 --- a/web/hooks/use-api-access-url.ts +++ b/web/hooks/use-api-access-url.ts @@ -3,5 +3,5 @@ import { useDocLink } from '@/context/i18n' export const useDatasetApiAccessUrl = () => { const docLink = useDocLink() - return docLink('/api-reference/datasets/get-knowledge-base-list') + return docLink('/api-reference/knowledge-bases/list-knowledge-bases') } diff --git a/web/i18n/en-US/workflow.json b/web/i18n/en-US/workflow.json index e5049069d6..dd9337ecc0 100644 --- a/web/i18n/en-US/workflow.json +++ b/web/i18n/en-US/workflow.json @@ -96,6 +96,8 @@ "chatVariable.modal.name": "Name", "chatVariable.modal.namePlaceholder": "Variable name", "chatVariable.modal.objectKey": "Key", + "chatVariable.modal.objectKeyPatternError": "Key can only contain letters, numbers, and underscores", + "chatVariable.modal.objectKeyRequired": "Object key cannot be empty", "chatVariable.modal.objectType": "Type", "chatVariable.modal.objectValue": "Default Value", "chatVariable.modal.oneByOne": "Add one by one", @@ -207,6 +209,7 @@ "common.runApp": "Run App", "common.runHistory": "Run History", "common.running": "Running", + "common.scheduleTriggerRunFailed": "Schedule trigger run failed", "common.searchVar": "Search variable", "common.setVarValuePlaceholder": "Set variable", "common.showRunHistory": "Show Run History", @@ -220,6 +223,8 @@ "common.viewDetailInTracingPanel": "View details", "common.viewOnly": "View Only", "common.viewRunHistory": "View run history", + "common.webhookDebugFailed": "Webhook debug failed", + "common.webhookDebugRequestFailed": "Webhook debug request failed", "common.workflowAsTool": "Workflow as Tool", "common.workflowAsToolDisabledHint": "Publish the latest workflow and ensure a connected User Input node before configuring it as a tool.", "common.workflowAsToolTip": "Tool reconfiguration is required after the workflow update.", @@ -293,6 +298,7 @@ "env.modal.type": "Type", "env.modal.value": "Value", "env.modal.valuePlaceholder": "env value", + "env.modal.valueRequired": "Value cannot be empty", "error.operations.addingNodes": "adding nodes", "error.operations.connectingNodes": "connecting nodes", "error.operations.modifyingWorkflow": "modifying workflow", @@ -513,7 +519,9 @@ "nodes.humanInput.deliveryMethod.contactTip2": "Tell us at support@dify.ai.", "nodes.humanInput.deliveryMethod.emailConfigure.allMembers": "All members ({{workspaceName}})", "nodes.humanInput.deliveryMethod.emailConfigure.body": "Body", + "nodes.humanInput.deliveryMethod.emailConfigure.bodyMustContainRequestURL": "Body must contain {{field}}", "nodes.humanInput.deliveryMethod.emailConfigure.bodyPlaceholder": "Enter email body", + "nodes.humanInput.deliveryMethod.emailConfigure.bodyRequired": "Body is required", "nodes.humanInput.deliveryMethod.emailConfigure.debugMode": "Debug Mode", "nodes.humanInput.deliveryMethod.emailConfigure.debugModeTip1": "In debug mode, the email will only be sent to your account email {{email}}.", "nodes.humanInput.deliveryMethod.emailConfigure.debugModeTip2": "The production environment is not affected.", @@ -524,9 +532,11 @@ "nodes.humanInput.deliveryMethod.emailConfigure.memberSelector.title": "Add workspace members or external recipients", "nodes.humanInput.deliveryMethod.emailConfigure.memberSelector.trigger": "Select", "nodes.humanInput.deliveryMethod.emailConfigure.recipient": "Recipient", + "nodes.humanInput.deliveryMethod.emailConfigure.recipientsRequired": "At least one recipient is required", "nodes.humanInput.deliveryMethod.emailConfigure.requestURLTip": "The request URL variable is the trigger entry for human input.", "nodes.humanInput.deliveryMethod.emailConfigure.subject": "Subject", "nodes.humanInput.deliveryMethod.emailConfigure.subjectPlaceholder": "Enter email subject", + "nodes.humanInput.deliveryMethod.emailConfigure.subjectRequired": "Subject is required", "nodes.humanInput.deliveryMethod.emailConfigure.title": "Email Configuration", "nodes.humanInput.deliveryMethod.emailSender.debugDone": "A test email has been sent to {{email}}. Please check your inbox.", "nodes.humanInput.deliveryMethod.emailSender.debugModeTip": "Debug mode is enabled.", @@ -741,6 +751,7 @@ "nodes.llm.jsonSchema.back": "Back", "nodes.llm.jsonSchema.descriptionPlaceholder": "Add description", "nodes.llm.jsonSchema.doc": "Learn more about structured output", + "nodes.llm.jsonSchema.fieldNameAlreadyExists": "Property name already exists", "nodes.llm.jsonSchema.fieldNamePlaceholder": "Field Name", "nodes.llm.jsonSchema.generate": "Generate", "nodes.llm.jsonSchema.generateJsonSchema": "Generate JSON Schema", diff --git a/web/types/doc-paths.ts b/web/types/doc-paths.ts index 8f95249354..9cbad79a2e 100644 --- a/web/types/doc-paths.ts +++ b/web/types/doc-paths.ts @@ -2,7 +2,7 @@ // DON NOT EDIT IT MANUALLY // // Generated from: https://raw.githubusercontent.com/langgenius/dify-docs/refs/heads/main/docs.json -// Generated at: 2026-01-30T09:14:29.304Z +// Generated at: 2026-03-25T03:18:49.626Z // Language prefixes export type DocLanguage = 'en' | 'zh' | 'ja' @@ -61,6 +61,7 @@ export type UseDifyPath = | '/use-dify/nodes/code' | '/use-dify/nodes/doc-extractor' | '/use-dify/nodes/http-request' + | '/use-dify/nodes/human-input' | '/use-dify/nodes/ifelse' | '/use-dify/nodes/iteration' | '/use-dify/nodes/knowledge-retrieval' @@ -82,6 +83,7 @@ export type UseDifyPath = | '/use-dify/publish/README' | '/use-dify/publish/developing-with-apis' | '/use-dify/publish/publish-mcp' + | '/use-dify/publish/publish-to-marketplace' | '/use-dify/publish/webapp/chatflow-webapp' | '/use-dify/publish/webapp/embedding-in-websites' | '/use-dify/publish/webapp/web-app-access' @@ -92,6 +94,16 @@ export type UseDifyPath = | '/use-dify/tutorials/customer-service-bot' | '/use-dify/tutorials/simple-chatbot' | '/use-dify/tutorials/twitter-chatflow' + | '/use-dify/tutorials/workflow-101/lesson-01' + | '/use-dify/tutorials/workflow-101/lesson-02' + | '/use-dify/tutorials/workflow-101/lesson-03' + | '/use-dify/tutorials/workflow-101/lesson-04' + | '/use-dify/tutorials/workflow-101/lesson-05' + | '/use-dify/tutorials/workflow-101/lesson-06' + | '/use-dify/tutorials/workflow-101/lesson-07' + | '/use-dify/tutorials/workflow-101/lesson-08' + | '/use-dify/tutorials/workflow-101/lesson-09' + | '/use-dify/tutorials/workflow-101/lesson-10' | '/use-dify/workspace/api-extension/api-extension' | '/use-dify/workspace/api-extension/cloudflare-worker' | '/use-dify/workspace/api-extension/external-data-tool-api-extension' @@ -167,72 +179,86 @@ export type DevelopPluginPath = // API Reference paths (English, use apiReferencePathTranslations for other languages) export type ApiReferencePath = + | '/api-reference/annotations/configure-annotation-reply' | '/api-reference/annotations/create-annotation' | '/api-reference/annotations/delete-annotation' - | '/api-reference/annotations/get-annotation-list' - | '/api-reference/annotations/initial-annotation-reply-settings' - | '/api-reference/annotations/query-initial-annotation-reply-settings-task-status' + | '/api-reference/annotations/get-annotation-reply-job-status' + | '/api-reference/annotations/list-annotations' | '/api-reference/annotations/update-annotation' - | '/api-reference/application/get-application-basic-information' - | '/api-reference/application/get-application-meta-information' - | '/api-reference/application/get-application-parameters-information' - | '/api-reference/application/get-application-webapp-settings' - | '/api-reference/chat/next-suggested-questions' - | '/api-reference/chat/send-chat-message' - | '/api-reference/chat/stop-chat-message-generation' - | '/api-reference/chatflow/next-suggested-questions' - | '/api-reference/chatflow/send-chat-message' - | '/api-reference/chatflow/stop-advanced-chat-message-generation' - | '/api-reference/chunks/add-chunks-to-a-document' + | '/api-reference/applications/get-app-info' + | '/api-reference/applications/get-app-meta' + | '/api-reference/applications/get-app-parameters' + | '/api-reference/applications/get-app-webapp-settings' + | '/api-reference/chats/get-next-suggested-questions' + | '/api-reference/chats/send-chat-message' + | '/api-reference/chats/stop-chat-message-generation' | '/api-reference/chunks/create-child-chunk' - | '/api-reference/chunks/delete-a-chunk-in-a-document' + | '/api-reference/chunks/create-chunks' | '/api-reference/chunks/delete-child-chunk' - | '/api-reference/chunks/get-a-chunk-details-in-a-document' - | '/api-reference/chunks/get-child-chunks' - | '/api-reference/chunks/get-chunks-from-a-document' - | '/api-reference/chunks/update-a-chunk-in-a-document' + | '/api-reference/chunks/delete-chunk' + | '/api-reference/chunks/get-chunk' + | '/api-reference/chunks/list-child-chunks' + | '/api-reference/chunks/list-chunks' | '/api-reference/chunks/update-child-chunk' - | '/api-reference/completion/create-completion-message' - | '/api-reference/completion/stop-generate' - | '/api-reference/conversations/conversation-rename' + | '/api-reference/chunks/update-chunk' + | '/api-reference/completions/send-completion-message' + | '/api-reference/completions/stop-completion-message-generation' | '/api-reference/conversations/delete-conversation' - | '/api-reference/conversations/get-conversation-history-messages' - | '/api-reference/conversations/get-conversation-variables' - | '/api-reference/conversations/get-conversations' - | '/api-reference/datasets/create-an-empty-knowledge-base' - | '/api-reference/datasets/delete-a-knowledge-base' - | '/api-reference/datasets/get-knowledge-base-details' - | '/api-reference/datasets/get-knowledge-base-list' - | '/api-reference/datasets/retrieve-chunks-from-a-knowledge-base-/-test-retrieval' - | '/api-reference/datasets/update-knowledge-base' - | '/api-reference/documents/create-a-document-from-a-file' - | '/api-reference/documents/create-a-document-from-text' - | '/api-reference/documents/delete-a-document' - | '/api-reference/documents/get-document-detail' - | '/api-reference/documents/get-document-embedding-status-(progress)' - | '/api-reference/documents/get-the-document-list-of-a-knowledge-base' - | '/api-reference/documents/update-a-document-with-a-file' - | '/api-reference/documents/update-a-document-with-text' - | '/api-reference/documents/update-document-status' - | '/api-reference/feedback/get-feedbacks-of-application' - | '/api-reference/feedback/message-feedback' - | '/api-reference/files/file-preview' - | '/api-reference/files/file-upload' - | '/api-reference/files/file-upload-for-workflow' - | '/api-reference/metadata-&-tags/bind-dataset-to-knowledge-base-type-tag' - | '/api-reference/metadata-&-tags/create-new-knowledge-base-type-tag' - | '/api-reference/metadata-&-tags/delete-knowledge-base-type-tag' - | '/api-reference/metadata-&-tags/get-knowledge-base-type-tags' - | '/api-reference/metadata-&-tags/modify-knowledge-base-type-tag-name' - | '/api-reference/metadata-&-tags/query-tags-bound-to-a-dataset' - | '/api-reference/metadata-&-tags/unbind-dataset-and-knowledge-base-type-tag' - | '/api-reference/models/get-available-embedding-models' - | '/api-reference/tts/speech-to-text' - | '/api-reference/tts/text-to-audio' - | '/api-reference/workflow-execution/execute-workflow' - | '/api-reference/workflow-execution/get-workflow-logs' - | '/api-reference/workflow-execution/get-workflow-run-detail' - | '/api-reference/workflow-execution/stop-workflow-task-generation' + | '/api-reference/conversations/list-conversation-messages' + | '/api-reference/conversations/list-conversation-variables' + | '/api-reference/conversations/list-conversations' + | '/api-reference/conversations/rename-conversation' + | '/api-reference/conversations/update-conversation-variable' + | '/api-reference/documents/create-document-by-file' + | '/api-reference/documents/create-document-by-text' + | '/api-reference/documents/delete-document' + | '/api-reference/documents/download-document' + | '/api-reference/documents/download-documents-as-zip' + | '/api-reference/documents/get-document' + | '/api-reference/documents/get-document-indexing-status' + | '/api-reference/documents/list-documents' + | '/api-reference/documents/update-document-by-file' + | '/api-reference/documents/update-document-by-text' + | '/api-reference/documents/update-document-status-in-batch' + | '/api-reference/end-users/get-end-user-info' + | '/api-reference/feedback/list-app-feedbacks' + | '/api-reference/feedback/submit-message-feedback' + | '/api-reference/files/download-file' + | '/api-reference/files/upload-file' + | '/api-reference/knowledge-bases/create-an-empty-knowledge-base' + | '/api-reference/knowledge-bases/delete-knowledge-base' + | '/api-reference/knowledge-bases/get-knowledge-base' + | '/api-reference/knowledge-bases/list-knowledge-bases' + | '/api-reference/knowledge-bases/retrieve-chunks-from-a-knowledge-base-/-test-retrieval' + | '/api-reference/knowledge-bases/update-knowledge-base' + | '/api-reference/knowledge-pipeline/list-datasource-plugins' + | '/api-reference/knowledge-pipeline/run-datasource-node' + | '/api-reference/knowledge-pipeline/run-pipeline' + | '/api-reference/knowledge-pipeline/upload-pipeline-file' + | '/api-reference/metadata/create-metadata-field' + | '/api-reference/metadata/delete-metadata-field' + | '/api-reference/metadata/get-built-in-metadata-fields' + | '/api-reference/metadata/list-metadata-fields' + | '/api-reference/metadata/update-built-in-metadata-field' + | '/api-reference/metadata/update-document-metadata-in-batch' + | '/api-reference/metadata/update-metadata-field' + | '/api-reference/models/get-available-models' + | '/api-reference/tags/create-knowledge-tag' + | '/api-reference/tags/create-tag-binding' + | '/api-reference/tags/delete-knowledge-tag' + | '/api-reference/tags/delete-tag-binding' + | '/api-reference/tags/get-knowledge-base-tags' + | '/api-reference/tags/list-knowledge-tags' + | '/api-reference/tags/update-knowledge-tag' + | '/api-reference/tts/convert-audio-to-text' + | '/api-reference/tts/convert-text-to-audio' + | '/api-reference/workflow-runs/get-workflow-run-detail' + | '/api-reference/workflow-runs/list-workflow-logs' + | '/api-reference/workflows/get-workflow-run-detail' + | '/api-reference/workflows/list-workflow-logs' + | '/api-reference/workflows/run-workflow' + | '/api-reference/workflows/run-workflow-by-id' + | '/api-reference/workflows/stop-workflow-task' // Base path without language prefix export type DocPathWithoutLangBase = @@ -251,70 +277,84 @@ export type DifyDocPath = `${DocLanguage}/${DocPathWithoutLang}` // API Reference path translations (English -> other languages) export const apiReferencePathTranslations: Record = { - '/api-reference/annotations/create-annotation': { zh: '/api-reference/标注管理/创建标注' }, - '/api-reference/annotations/delete-annotation': { zh: '/api-reference/标注管理/删除标注' }, - '/api-reference/annotations/get-annotation-list': { zh: '/api-reference/标注管理/获取标注列表' }, - '/api-reference/annotations/initial-annotation-reply-settings': { zh: '/api-reference/标注管理/标注回复初始设置' }, - '/api-reference/annotations/query-initial-annotation-reply-settings-task-status': { zh: '/api-reference/标注管理/查询标注回复初始设置任务状态' }, - '/api-reference/annotations/update-annotation': { zh: '/api-reference/标注管理/更新标注' }, - '/api-reference/application/get-application-basic-information': { zh: '/api-reference/应用设置/获取应用基本信息', ja: '/api-reference/アプリケーション情報/アプリケーションの基本情報を取得' }, - '/api-reference/application/get-application-meta-information': { zh: '/api-reference/应用配置/获取应用meta信息', ja: '/api-reference/アプリケーション設定/アプリケーションのメタ情報を取得' }, - '/api-reference/application/get-application-parameters-information': { zh: '/api-reference/应用设置/获取应用参数', ja: '/api-reference/アプリケーション情報/アプリケーションのパラメータ情報を取得' }, - '/api-reference/application/get-application-webapp-settings': { zh: '/api-reference/应用设置/获取应用-webapp-设置', ja: '/api-reference/アプリケーション情報/アプリのwebapp設定を取得' }, - '/api-reference/chat/next-suggested-questions': { zh: '/api-reference/对话消息/获取下一轮建议问题列表', ja: '/api-reference/チャットメッセージ/次の推奨質問' }, - '/api-reference/chat/send-chat-message': { zh: '/api-reference/对话消息/发送对话消息', ja: '/api-reference/チャットメッセージ/チャットメッセージを送信' }, - '/api-reference/chat/stop-chat-message-generation': { zh: '/api-reference/对话消息/停止响应', ja: '/api-reference/チャットメッセージ/生成停止' }, - '/api-reference/chatflow/next-suggested-questions': { zh: '/api-reference/对话消息/获取下一轮建议问题列表', ja: '/api-reference/チャットメッセージ/次の推奨質問' }, - '/api-reference/chatflow/send-chat-message': { zh: '/api-reference/对话消息/发送对话消息', ja: '/api-reference/チャットメッセージ/チャットメッセージを送信' }, - '/api-reference/chatflow/stop-advanced-chat-message-generation': { zh: '/api-reference/对话消息/停止响应', ja: '/api-reference/チャットメッセージ/生成を停止' }, - '/api-reference/chunks/add-chunks-to-a-document': { zh: '/api-reference/文档块/向文档添加块', ja: '/api-reference/チャンク/ドキュメントにチャンクを追加' }, - '/api-reference/chunks/create-child-chunk': { zh: '/api-reference/文档块/创建子块', ja: '/api-reference/チャンク/子チャンクを作成' }, - '/api-reference/chunks/delete-a-chunk-in-a-document': { zh: '/api-reference/文档块/删除文档中的块', ja: '/api-reference/チャンク/ドキュメント内のチャンクを削除' }, - '/api-reference/chunks/delete-child-chunk': { zh: '/api-reference/文档块/删除子块', ja: '/api-reference/チャンク/子チャンクを削除' }, - '/api-reference/chunks/get-a-chunk-details-in-a-document': { zh: '/api-reference/文档块/获取文档中的块详情', ja: '/api-reference/チャンク/ドキュメント内のチャンク詳細を取得' }, - '/api-reference/chunks/get-child-chunks': { zh: '/api-reference/文档块/获取子块', ja: '/api-reference/チャンク/子チャンクを取得' }, - '/api-reference/chunks/get-chunks-from-a-document': { zh: '/api-reference/文档块/从文档获取块', ja: '/api-reference/チャンク/ドキュメントからチャンクを取得' }, - '/api-reference/chunks/update-a-chunk-in-a-document': { zh: '/api-reference/文档块/更新文档中的块', ja: '/api-reference/チャンク/ドキュメント内のチャンクを更新' }, - '/api-reference/chunks/update-child-chunk': { zh: '/api-reference/文档块/更新子块', ja: '/api-reference/チャンク/子チャンクを更新' }, - '/api-reference/completion/create-completion-message': { zh: '/api-reference/文本生成/发送消息', ja: '/api-reference/完了メッセージ/完了メッセージの作成' }, - '/api-reference/completion/stop-generate': { zh: '/api-reference/文本生成/停止响应', ja: '/api-reference/完了メッセージ/生成の停止' }, - '/api-reference/conversations/conversation-rename': { zh: '/api-reference/会话管理/会话重命名', ja: '/api-reference/会話管理/会話の名前を変更' }, + '/api-reference/annotations/configure-annotation-reply': { zh: '/api-reference/标注管理/配置标注回复', ja: '/api-reference/アノテーション管理/アノテーション返信を設定' }, + '/api-reference/annotations/create-annotation': { zh: '/api-reference/标注管理/创建标注', ja: '/api-reference/アノテーション管理/アノテーションを作成' }, + '/api-reference/annotations/delete-annotation': { zh: '/api-reference/标注管理/删除标注', ja: '/api-reference/アノテーション管理/アノテーションを削除' }, + '/api-reference/annotations/get-annotation-reply-job-status': { zh: '/api-reference/标注管理/查询标注回复配置任务状态', ja: '/api-reference/アノテーション管理/アノテーション返信の初期設定タスクステータスを取得' }, + '/api-reference/annotations/list-annotations': { zh: '/api-reference/标注管理/获取标注列表', ja: '/api-reference/アノテーション管理/アノテーションリストを取得' }, + '/api-reference/annotations/update-annotation': { zh: '/api-reference/标注管理/更新标注', ja: '/api-reference/アノテーション管理/アノテーションを更新' }, + '/api-reference/applications/get-app-info': { zh: '/api-reference/应用配置/获取应用基本信息', ja: '/api-reference/アプリケーション設定/アプリケーションの基本情報を取得' }, + '/api-reference/applications/get-app-meta': { zh: '/api-reference/应用配置/获取应用元数据', ja: '/api-reference/アプリケーション設定/アプリケーションのメタ情報を取得' }, + '/api-reference/applications/get-app-parameters': { zh: '/api-reference/应用配置/获取应用参数', ja: '/api-reference/アプリケーション設定/アプリケーションのパラメータ情報を取得' }, + '/api-reference/applications/get-app-webapp-settings': { zh: '/api-reference/应用配置/获取应用-webapp-设置', ja: '/api-reference/アプリケーション設定/アプリの-webapp-設定を取得' }, + '/api-reference/chats/get-next-suggested-questions': { zh: '/api-reference/对话消息/获取下一轮建议问题列表', ja: '/api-reference/チャットメッセージ/次の推奨質問を取得' }, + '/api-reference/chats/send-chat-message': { zh: '/api-reference/对话消息/发送对话消息', ja: '/api-reference/チャットメッセージ/チャットメッセージを送信' }, + '/api-reference/chats/stop-chat-message-generation': { zh: '/api-reference/对话消息/停止响应', ja: '/api-reference/チャットメッセージ/生成を停止' }, + '/api-reference/chunks/create-child-chunk': { zh: '/api-reference/分段/创建子分段', ja: '/api-reference/チャンク/子チャンクを作成' }, + '/api-reference/chunks/create-chunks': { zh: '/api-reference/分段/向文档添加分段', ja: '/api-reference/チャンク/ドキュメントにチャンクを追加' }, + '/api-reference/chunks/delete-child-chunk': { zh: '/api-reference/分段/删除子分段', ja: '/api-reference/チャンク/子チャンクを削除' }, + '/api-reference/chunks/delete-chunk': { zh: '/api-reference/分段/删除文档中的分段', ja: '/api-reference/チャンク/ドキュメント内のチャンクを削除' }, + '/api-reference/chunks/get-chunk': { zh: '/api-reference/分段/获取文档中的分段详情', ja: '/api-reference/チャンク/ドキュメント内のチャンク詳細を取得' }, + '/api-reference/chunks/list-child-chunks': { zh: '/api-reference/分段/获取子分段', ja: '/api-reference/チャンク/子チャンク一覧を取得' }, + '/api-reference/chunks/list-chunks': { zh: '/api-reference/分段/从文档获取分段', ja: '/api-reference/チャンク/チャンク一覧を取得' }, + '/api-reference/chunks/update-child-chunk': { zh: '/api-reference/分段/更新子分段', ja: '/api-reference/チャンク/子チャンクを更新' }, + '/api-reference/chunks/update-chunk': { zh: '/api-reference/分段/更新文档中的分段', ja: '/api-reference/チャンク/ドキュメント内のチャンクを更新' }, + '/api-reference/completions/send-completion-message': { zh: '/api-reference/文本生成/发送消息', ja: '/api-reference/完了メッセージ/完了メッセージを送信' }, + '/api-reference/completions/stop-completion-message-generation': { zh: '/api-reference/文本生成/停止响应', ja: '/api-reference/完了メッセージ/生成を停止' }, '/api-reference/conversations/delete-conversation': { zh: '/api-reference/会话管理/删除会话', ja: '/api-reference/会話管理/会話を削除' }, - '/api-reference/conversations/get-conversation-history-messages': { zh: '/api-reference/会话管理/获取会话历史消息', ja: '/api-reference/会話管理/会話履歴メッセージを取得' }, - '/api-reference/conversations/get-conversation-variables': { zh: '/api-reference/会话管理/获取对话变量', ja: '/api-reference/会話管理/会話変数の取得' }, - '/api-reference/conversations/get-conversations': { zh: '/api-reference/会话管理/获取会话列表', ja: '/api-reference/会話管理/会話を取得' }, - '/api-reference/datasets/create-an-empty-knowledge-base': { zh: '/api-reference/数据集/创建空知识库', ja: '/api-reference/データセット/空のナレッジベースを作成' }, - '/api-reference/datasets/delete-a-knowledge-base': { zh: '/api-reference/数据集/删除知识库', ja: '/api-reference/データセット/ナレッジベースを削除' }, - '/api-reference/datasets/get-knowledge-base-details': { zh: '/api-reference/数据集/获取知识库详情', ja: '/api-reference/データセット/ナレッジベース詳細を取得' }, - '/api-reference/datasets/get-knowledge-base-list': { zh: '/api-reference/数据集/获取知识库列表', ja: '/api-reference/データセット/ナレッジベースリストを取得' }, - '/api-reference/datasets/retrieve-chunks-from-a-knowledge-base-/-test-retrieval': { zh: '/api-reference/数据集/从知识库检索块-/-测试检索', ja: '/api-reference/データセット/ナレッジベースからチャンクを取得-/-テスト検索' }, - '/api-reference/datasets/update-knowledge-base': { zh: '/api-reference/数据集/更新知识库', ja: '/api-reference/データセット/ナレッジベースを更新' }, - '/api-reference/documents/create-a-document-from-a-file': { zh: '/api-reference/文档/从文件创建文档', ja: '/api-reference/ドキュメント/ファイルからドキュメントを作成' }, - '/api-reference/documents/create-a-document-from-text': { zh: '/api-reference/文档/从文本创建文档', ja: '/api-reference/ドキュメント/テキストからドキュメントを作成' }, - '/api-reference/documents/delete-a-document': { zh: '/api-reference/文档/删除文档', ja: '/api-reference/ドキュメント/ドキュメントを削除' }, - '/api-reference/documents/get-document-detail': { zh: '/api-reference/文档/获取文档详情', ja: '/api-reference/ドキュメント/ドキュメント詳細を取得' }, - '/api-reference/documents/get-document-embedding-status-(progress)': { zh: '/api-reference/文档/获取文档嵌入状态(进度)', ja: '/api-reference/ドキュメント/ドキュメント埋め込みステータス(進捗)を取得' }, - '/api-reference/documents/get-the-document-list-of-a-knowledge-base': { zh: '/api-reference/文档/获取知识库的文档列表', ja: '/api-reference/ドキュメント/ナレッジベースのドキュメントリストを取得' }, - '/api-reference/documents/update-a-document-with-a-file': { zh: '/api-reference/文档/用文件更新文档', ja: '/api-reference/ドキュメント/ファイルでドキュメントを更新' }, - '/api-reference/documents/update-a-document-with-text': { zh: '/api-reference/文档/用文本更新文档', ja: '/api-reference/ドキュメント/テキストでドキュメントを更新' }, - '/api-reference/documents/update-document-status': { zh: '/api-reference/文档/更新文档状态', ja: '/api-reference/ドキュメント/ドキュメントステータスを更新' }, - '/api-reference/feedback/get-feedbacks-of-application': { zh: '/api-reference/反馈/获取应用反馈列表', ja: '/api-reference/メッセージフィードバック/アプリのメッセージの「いいね」とフィードバックを取得' }, - '/api-reference/feedback/message-feedback': { zh: '/api-reference/反馈/消息反馈(点赞)', ja: '/api-reference/メッセージフィードバック/メッセージフィードバック' }, - '/api-reference/files/file-preview': { zh: '/api-reference/文件操作/文件预览', ja: '/api-reference/ファイル操作/ファイルプレビュー' }, - '/api-reference/files/file-upload': { zh: '/api-reference/文件管理/上传文件', ja: '/api-reference/ファイル操作/ファイルアップロード' }, - '/api-reference/files/file-upload-for-workflow': { zh: '/api-reference/文件操作-(workflow)/上传文件-(workflow)', ja: '/api-reference/ファイル操作-(ワークフロー)/ファイルアップロード-(ワークフロー用)' }, - '/api-reference/metadata-&-tags/bind-dataset-to-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/将数据集绑定到知识库类型标签', ja: '/api-reference/メタデータ・タグ/データセットをナレッジベースタイプタグにバインド' }, - '/api-reference/metadata-&-tags/create-new-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/创建新的知识库类型标签', ja: '/api-reference/メタデータ・タグ/新しいナレッジベースタイプタグを作成' }, - '/api-reference/metadata-&-tags/delete-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/删除知识库类型标签', ja: '/api-reference/メタデータ・タグ/ナレッジベースタイプタグを削除' }, - '/api-reference/metadata-&-tags/get-knowledge-base-type-tags': { zh: '/api-reference/元数据和标签/获取知识库类型标签', ja: '/api-reference/メタデータ・タグ/ナレッジベースタイプタグを取得' }, - '/api-reference/metadata-&-tags/modify-knowledge-base-type-tag-name': { zh: '/api-reference/元数据和标签/修改知识库类型标签名称', ja: '/api-reference/メタデータ・タグ/ナレッジベースタイプタグ名を変更' }, - '/api-reference/metadata-&-tags/query-tags-bound-to-a-dataset': { zh: '/api-reference/元数据和标签/查询绑定到数据集的标签', ja: '/api-reference/メタデータ・タグ/データセットにバインドされたタグをクエリ' }, - '/api-reference/metadata-&-tags/unbind-dataset-and-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/解绑数据集和知识库类型标签', ja: '/api-reference/メタデータ・タグ/データセットとナレッジベースタイプタグのバインドを解除' }, - '/api-reference/models/get-available-embedding-models': { zh: '/api-reference/模型/获取可用的嵌入模型', ja: '/api-reference/モデル/利用可能な埋め込みモデルを取得' }, - '/api-reference/tts/speech-to-text': { zh: '/api-reference/语音与文字转换/语音转文字', ja: '/api-reference/音声・テキスト変換/音声からテキストへ' }, - '/api-reference/tts/text-to-audio': { zh: '/api-reference/语音服务/文字转语音', ja: '/api-reference/音声変換/テキストから音声' }, - '/api-reference/workflow-execution/execute-workflow': { zh: '/api-reference/工作流执行/执行-workflow', ja: '/api-reference/ワークフロー実行/ワークフローを実行' }, - '/api-reference/workflow-execution/get-workflow-logs': { zh: '/api-reference/工作流执行/获取-workflow-日志', ja: '/api-reference/ワークフロー実行/ワークフローログを取得' }, - '/api-reference/workflow-execution/get-workflow-run-detail': { zh: '/api-reference/工作流执行/获取workflow执行情况', ja: '/api-reference/ワークフロー実行/ワークフロー実行詳細を取得' }, - '/api-reference/workflow-execution/stop-workflow-task-generation': { zh: '/api-reference/工作流执行/停止响应-(workflow-task)', ja: '/api-reference/ワークフロー実行/生成を停止-(ワークフロータスク)' }, + '/api-reference/conversations/list-conversation-messages': { zh: '/api-reference/会话管理/获取会话历史消息', ja: '/api-reference/会話管理/会話履歴メッセージ一覧を取得' }, + '/api-reference/conversations/list-conversation-variables': { zh: '/api-reference/会话管理/获取对话变量', ja: '/api-reference/会話管理/会話変数の取得' }, + '/api-reference/conversations/list-conversations': { zh: '/api-reference/会话管理/获取会话列表', ja: '/api-reference/会話管理/会話一覧を取得' }, + '/api-reference/conversations/rename-conversation': { zh: '/api-reference/会话管理/重命名会话', ja: '/api-reference/会話管理/会話の名前を変更' }, + '/api-reference/conversations/update-conversation-variable': { zh: '/api-reference/会话管理/更新对话变量', ja: '/api-reference/会話管理/会話変数を更新' }, + '/api-reference/documents/create-document-by-file': { zh: '/api-reference/文档/从文件创建文档', ja: '/api-reference/ドキュメント/ファイルからドキュメントを作成' }, + '/api-reference/documents/create-document-by-text': { zh: '/api-reference/文档/从文本创建文档', ja: '/api-reference/ドキュメント/テキストからドキュメントを作成' }, + '/api-reference/documents/delete-document': { zh: '/api-reference/文档/删除文档', ja: '/api-reference/ドキュメント/ドキュメントを削除' }, + '/api-reference/documents/download-document': { zh: '/api-reference/文档/下载文档', ja: '/api-reference/ドキュメント/ドキュメントをダウンロード' }, + '/api-reference/documents/download-documents-as-zip': { zh: '/api-reference/文档/批量下载文档(zip)', ja: '/api-reference/ドキュメント/ドキュメントを一括ダウンロード(zip)' }, + '/api-reference/documents/get-document': { zh: '/api-reference/文档/获取文档详情', ja: '/api-reference/ドキュメント/ドキュメント詳細を取得' }, + '/api-reference/documents/get-document-indexing-status': { zh: '/api-reference/文档/获取文档嵌入状态(进度)', ja: '/api-reference/ドキュメント/ドキュメント埋め込みステータス(進捗)を取得' }, + '/api-reference/documents/list-documents': { zh: '/api-reference/文档/获取知识库的文档列表', ja: '/api-reference/ドキュメント/ナレッジベースのドキュメントリストを取得' }, + '/api-reference/documents/update-document-by-file': { zh: '/api-reference/文档/用文件更新文档', ja: '/api-reference/ドキュメント/ファイルでドキュメントを更新' }, + '/api-reference/documents/update-document-by-text': { zh: '/api-reference/文档/用文本更新文档', ja: '/api-reference/ドキュメント/テキストでドキュメントを更新' }, + '/api-reference/documents/update-document-status-in-batch': { zh: '/api-reference/文档/批量更新文档状态', ja: '/api-reference/ドキュメント/ドキュメントステータスを一括更新' }, + '/api-reference/end-users/get-end-user-info': { zh: '/api-reference/终端用户/获取终端用户', ja: '/api-reference/エンドユーザー/エンドユーザー取得' }, + '/api-reference/feedback/list-app-feedbacks': { zh: '/api-reference/消息反馈/获取应用的消息反馈', ja: '/api-reference/メッセージフィードバック/アプリのフィードバック一覧を取得' }, + '/api-reference/feedback/submit-message-feedback': { zh: '/api-reference/消息反馈/提交消息反馈', ja: '/api-reference/メッセージフィードバック/メッセージフィードバックを送信' }, + '/api-reference/files/download-file': { zh: '/api-reference/文件操作/下载文件', ja: '/api-reference/ファイル操作/ファイルをダウンロード' }, + '/api-reference/files/upload-file': { zh: '/api-reference/文件操作/上传文件', ja: '/api-reference/ファイル操作/ファイルをアップロード' }, + '/api-reference/knowledge-bases/create-an-empty-knowledge-base': { zh: '/api-reference/知识库/创建空知识库', ja: '/api-reference/データセット/空のナレッジベースを作成' }, + '/api-reference/knowledge-bases/delete-knowledge-base': { zh: '/api-reference/知识库/删除知识库', ja: '/api-reference/データセット/ナレッジベースを削除' }, + '/api-reference/knowledge-bases/get-knowledge-base': { zh: '/api-reference/知识库/获取知识库详情', ja: '/api-reference/データセット/ナレッジベース詳細を取得' }, + '/api-reference/knowledge-bases/list-knowledge-bases': { zh: '/api-reference/知识库/获取知识库列表', ja: '/api-reference/データセット/ナレッジベースリストを取得' }, + '/api-reference/knowledge-bases/retrieve-chunks-from-a-knowledge-base-/-test-retrieval': { zh: '/api-reference/知识库/从知识库检索分段-/-测试检索', ja: '/api-reference/データセット/ナレッジベースからチャンクを取得-/-テスト検索' }, + '/api-reference/knowledge-bases/update-knowledge-base': { zh: '/api-reference/知识库/更新知识库', ja: '/api-reference/データセット/ナレッジベースを更新' }, + '/api-reference/knowledge-pipeline/list-datasource-plugins': { zh: '/api-reference/知识流水线/获取数据源插件列表', ja: '/api-reference/ナレッジパイプライン/データソースプラグインリストを取得' }, + '/api-reference/knowledge-pipeline/run-datasource-node': { zh: '/api-reference/知识流水线/执行数据源节点', ja: '/api-reference/ナレッジパイプライン/データソースノードを実行' }, + '/api-reference/knowledge-pipeline/run-pipeline': { zh: '/api-reference/知识流水线/运行流水线', ja: '/api-reference/ナレッジパイプライン/パイプラインを実行' }, + '/api-reference/knowledge-pipeline/upload-pipeline-file': { zh: '/api-reference/知识流水线/上传流水线文件', ja: '/api-reference/ナレッジパイプライン/パイプラインファイルをアップロード' }, + '/api-reference/metadata/create-metadata-field': { zh: '/api-reference/元数据/创建元数据字段', ja: '/api-reference/メタデータ/メタデータフィールドを作成' }, + '/api-reference/metadata/delete-metadata-field': { zh: '/api-reference/元数据/删除元数据字段', ja: '/api-reference/メタデータ/メタデータフィールドを削除' }, + '/api-reference/metadata/get-built-in-metadata-fields': { zh: '/api-reference/元数据/获取内置元数据字段', ja: '/api-reference/メタデータ/組み込みメタデータフィールドを取得' }, + '/api-reference/metadata/list-metadata-fields': { zh: '/api-reference/元数据/获取元数据字段列表', ja: '/api-reference/メタデータ/メタデータフィールドリストを取得' }, + '/api-reference/metadata/update-built-in-metadata-field': { zh: '/api-reference/元数据/更新内置元数据字段', ja: '/api-reference/メタデータ/組み込みメタデータフィールドを更新' }, + '/api-reference/metadata/update-document-metadata-in-batch': { zh: '/api-reference/元数据/批量更新文档元数据', ja: '/api-reference/メタデータ/ドキュメントメタデータを一括更新' }, + '/api-reference/metadata/update-metadata-field': { zh: '/api-reference/元数据/更新元数据字段', ja: '/api-reference/メタデータ/メタデータフィールドを更新' }, + '/api-reference/models/get-available-models': { zh: '/api-reference/模型/获取可用模型', ja: '/api-reference/モデル/利用可能なモデルを取得' }, + '/api-reference/tags/create-knowledge-tag': { zh: '/api-reference/标签/创建知识库标签', ja: '/api-reference/タグ管理/ナレッジベースタグを作成' }, + '/api-reference/tags/create-tag-binding': { zh: '/api-reference/标签/绑定标签到知识库', ja: '/api-reference/タグ管理/タグをデータセットにバインド' }, + '/api-reference/tags/delete-knowledge-tag': { zh: '/api-reference/标签/删除知识库标签', ja: '/api-reference/タグ管理/ナレッジベースタグを削除' }, + '/api-reference/tags/delete-tag-binding': { zh: '/api-reference/标签/解除标签与知识库的绑定', ja: '/api-reference/タグ管理/タグとデータセットのバインドを解除' }, + '/api-reference/tags/get-knowledge-base-tags': { zh: '/api-reference/标签/获取知识库绑定的标签', ja: '/api-reference/タグ管理/ナレッジベースにバインドされたタグを取得' }, + '/api-reference/tags/list-knowledge-tags': { zh: '/api-reference/标签/获取知识库标签列表', ja: '/api-reference/タグ管理/ナレッジベースタグリストを取得' }, + '/api-reference/tags/update-knowledge-tag': { zh: '/api-reference/标签/修改知识库标签', ja: '/api-reference/タグ管理/ナレッジベースタグを変更' }, + '/api-reference/tts/convert-audio-to-text': { zh: '/api-reference/语音与文字转换/语音转文字', ja: '/api-reference/音声・テキスト変換/音声をテキストに変換' }, + '/api-reference/tts/convert-text-to-audio': { zh: '/api-reference/语音与文字转换/文字转语音', ja: '/api-reference/音声・テキスト変換/テキストを音声に変換' }, + '/api-reference/workflow-runs/get-workflow-run-detail': { zh: '/api-reference/工作流执行/获取工作流执行情况', ja: '/api-reference/ワークフロー実行/ワークフロー実行詳細を取得' }, + '/api-reference/workflow-runs/list-workflow-logs': { zh: '/api-reference/工作流执行/获取工作流日志', ja: '/api-reference/ワークフロー実行/ワークフローログ一覧を取得' }, + '/api-reference/workflows/get-workflow-run-detail': { zh: '/api-reference/工作流/获取工作流执行情况', ja: '/api-reference/ワークフロー/ワークフロー実行詳細を取得' }, + '/api-reference/workflows/list-workflow-logs': { zh: '/api-reference/工作流/获取工作流日志', ja: '/api-reference/ワークフロー/ワークフローログ一覧を取得' }, + '/api-reference/workflows/run-workflow': { zh: '/api-reference/工作流/执行工作流', ja: '/api-reference/ワークフロー/ワークフローを実行' }, + '/api-reference/workflows/run-workflow-by-id': { zh: '/api-reference/工作流/按-id-执行工作流', ja: '/api-reference/ワークフロー/id-でワークフローを実行' }, + '/api-reference/workflows/stop-workflow-task': { zh: '/api-reference/工作流/停止工作流任务', ja: '/api-reference/ワークフロー/ワークフロータスクを停止' }, }