diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index b1ccf496df..65f0149a74 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -110,6 +110,28 @@ jobs: sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env + # hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d` + # to return (container processes started); it does not wait on healthcheck + # status. mysql:8.0's first-time init takes 15-30s, so without an explicit + # wait the migration runs while InnoDB is still initialising and gets + # killed with "Lost connection during query". Poll a real SELECT until it + # succeeds. + - name: Wait for MySQL to accept queries + run: | + set +e + for i in $(seq 1 60); do + if docker run --rm --network host mysql:8.0 \ + mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \ + -e 'SELECT 1' >/dev/null 2>&1; then + echo "MySQL ready after ${i}s" + exit 0 + fi + sleep 1 + done + echo "MySQL not ready after 60s; dumping container logs:" + docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql + exit 1 + - name: Run DB Migration env: DEBUG: true diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml index a634830fef..bdc24887db 100644 --- a/.github/workflows/web-e2e.yml +++ b/.github/workflows/web-e2e.yml @@ -13,7 +13,7 @@ concurrency: jobs: test: name: Web Full-Stack E2E - runs-on: depot-ubuntu-24.04 + runs-on: depot-ubuntu-24.04-4 defaults: run: shell: bash diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index bc28ecb6b7..0b09facf58 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -468,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 +def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]: + """Update a document from an uploaded file for canonical and deprecated routes.""" + dataset_id_str = str(dataset_id) + tenant_id_str = str(tenant_id) + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1) + ) + + if not dataset: + raise ValueError("Dataset does not exist.") + + if dataset.provider == "external": + raise ValueError("External datasets are not supported.") + + args: dict[str, object] = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # indexing_technique is already set in dataset since this is an update + args["indexing_technique"] = dataset.indexing_technique + + if "file" in request.files: + # save file info + file = request.files["file"] + + if len(request.files) > 1: + raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if not current_user: + raise ValueError("current_user is required") + + try: + upload_file = FileService(db.engine).upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except 
services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } + args["data_source"] = data_source + + # validate args + args["original_document_id"] = str(document_id) + + knowledge_config = KnowledgeConfig.model_validate(args) + DocumentService.document_create_args_validate(knowledge_config) + + try: + documents, _ = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=dataset.created_by_account, + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch} + return documents_and_batch_fields, 200 + + @service_api_ns.route( "/datasets//documents//update_by_file", "/datasets//documents//update-by-file", ) -class DocumentUpdateByFileApi(DatasetApiResource): - """Resource for update documents.""" +class DeprecatedDocumentUpdateByFileApi(DatasetApiResource): + """Deprecated resource aliases for file document updates.""" - @service_api_ns.doc("update_document_by_file") - @service_api_ns.doc(description="Update an existing document by uploading a file") + @service_api_ns.doc("update_document_by_file_deprecated") + @service_api_ns.doc(deprecated=True) + @service_api_ns.doc( + description=( + "Deprecated legacy alias for updating an existing document by uploading a file. " + "Use PATCH /datasets/{dataset_id}/documents/{document_id} instead." + ) + ) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc( responses={ @@ -487,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource): ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id): - """Update document by upload file.""" - dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) - ) - - if not dataset: - raise ValueError("Dataset does not exist.") - - if dataset.provider == "external": - raise ValueError("External datasets are not supported.") - - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = dataset.chunk_structure or "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - - # indexing_technique is already set in dataset since this is an update - args["indexing_technique"] = dataset.indexing_technique - - if "file" in request.files: - # save file info - file = request.files["file"] - - if len(request.files) > 1: - raise TooManyFilesError() - - if not file.filename: - raise FilenameNotExistsError - - if not current_user: - raise ValueError("current_user is required") - - try: - upload_file = FileService(db.engine).upload_file( - filename=file.filename, - content=file.read(), - mimetype=file.mimetype, - user=current_user, - source="datasets", - ) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except 
services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - data_source = { - "type": "upload_file", - "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, - } - args["data_source"] = data_source - # validate args - args["original_document_id"] = str(document_id) - - knowledge_config = KnowledgeConfig.model_validate(args) - DocumentService.document_create_args_validate(knowledge_config) - - try: - documents, _ = DocumentService.save_document_with_dataset_id( - dataset=dataset, - knowledge_config=knowledge_config, - account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, - created_from="api", - ) - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - document = documents[0] - documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch} - return documents_and_batch_fields, 200 + def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): + """Update document by file through the deprecated file-update aliases.""" + return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id) @service_api_ns.route("/datasets//documents") @@ -876,6 +886,22 @@ class DocumentApi(DatasetApiResource): return response + @service_api_ns.doc("update_document_by_file") + @service_api_ns.doc(description="Update an existing document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID): + """Update document by file on the canonical document resource.""" + return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id) + @service_api_ns.doc("delete_document") @service_api_ns.doc(description="Delete a document") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index dddd5fc994..fa36c6ee19 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -39,6 +39,58 @@ class AbstractVectorFactory(ABC): return index_struct_dict +class _LazyEmbeddings(Embeddings): + """Lazy proxy that defers materializing the real embedding model. + + Constructing the real embeddings (via ``ModelManager.get_model_instance``) + transitively calls ``FeatureService.get_features`` → ``BillingService`` + HTTP GETs (see ``provider_manager.py``). Cleanup paths + (``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings + at all, so deferring this until an ``embed_*`` method is actually invoked + keeps cleanup tasks resilient to transient billing-API failures and avoids + leaving stranded ``document_segments`` / ``child_chunks`` whenever billing + hiccups. + + Existing callers that perform create / search operations are unaffected: + the first ``embed_*`` call materializes the underlying model and the + behavior is identical from that point on. 
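+
+    A minimal usage sketch (names as defined in this class; only the first
+    ``embed_*`` call resolves the real model)::
+
+        embeddings = _LazyEmbeddings(dataset)  # no ModelManager call yet
+        embeddings.embed_query("hello")  # builds CacheEmbedding on first use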
+ """ + + def __init__(self, dataset: Dataset): + self._dataset = dataset + self._real: Embeddings | None = None + + def _ensure(self) -> Embeddings: + if self._real is None: + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) + embedding_model = model_manager.get_model_instance( + tenant_id=self._dataset.tenant_id, + provider=self._dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=self._dataset.embedding_model, + ) + self._real = CacheEmbedding(embedding_model) + return self._real + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return self._ensure().embed_documents(texts) + + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: + return self._ensure().embed_multimodal_documents(multimodel_documents) + + def embed_query(self, text: str) -> list[float]: + return self._ensure().embed_query(text) + + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: + return self._ensure().embed_multimodal_query(multimodel_document) + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + return await self._ensure().aembed_documents(texts) + + async def aembed_query(self, text: str) -> list[float]: + return await self._ensure().aembed_query(text) + + class Vector: def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: @@ -60,7 +112,11 @@ class Vector: "original_chunk_id", ] self._dataset = dataset - self._embeddings = self._get_embeddings() + # Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists) + # never transitively trigger billing API calls during ``Vector(dataset)`` + # construction. The real embedding model is materialized only when an + # ``embed_*`` method is actually invoked (i.e. create / search paths). + self._embeddings: Embeddings = _LazyEmbeddings(dataset) self._attributes = attributes self._vector_processor = self._init_vector() diff --git a/api/models/model.py b/api/models/model.py index f0f8d60cdc..29ecaf3073 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -2182,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. return result -class UploadFile(Base): +class UploadFile(TypeBase): __tablename__ = "upload_files" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -2190,9 +2190,12 @@ class UploadFile(Base): ) # NOTE: The `id` field is generated within the application to minimize extra roundtrips - # (especially when generating `source_url`). - # The `server_default` serves as a fallback mechanism. - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + # (especially when generating `source_url`) and keep model metadata portable across databases. + id: Mapped[str] = mapped_column( + StringUUID, + init=False, + default_factory=lambda: str(uuid4()), + ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False) key: Mapped[str] = mapped_column(String(255), nullable=False) @@ -2200,16 +2203,6 @@ class UploadFile(Base): size: Mapped[int] = mapped_column(sa.Integer, nullable=False) extension: Mapped[str] = mapped_column(String(255), nullable=False) mime_type: Mapped[str] = mapped_column(String(255), nullable=True) - - # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. 
- # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[CreatorUserRole] = mapped_column( - EnumText(CreatorUserRole, length=255), - nullable=False, - server_default=sa.text("'account'"), - default=CreatorUserRole.ACCOUNT, - ) - # The `created_by` field stores the ID of the entity that created this upload file. # # If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`. @@ -2228,10 +2221,18 @@ class UploadFile(Base): # `used` may indicate whether the file has been utilized by another service. used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. + # Its value is derived from the `CreatorUserRole` enumeration. + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # `used_by` may indicate the ID of the user who utilized this file. - used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(String(255), nullable=True) + used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) source_url: Mapped[str] = mapped_column(LongText, default="") def __init__( diff --git a/api/models/workflow.py b/api/models/workflow.py index bb4c24380f..cb90b4127d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -50,7 +50,7 @@ from libs.uuid_utils import uuidv7 from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: - from .model import AppMode, UploadFile + from .model import AppMode from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase @@ -64,6 +64,10 @@ from .account import Account from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom + +# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships +# must target the class object directly instead of relying on string lookup across registries. 
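+# The ``primaryjoin`` arguments below are likewise switched from strings to
+# callables so they never go through registry string resolution.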
+from .model import UploadFile from .types import EnumText, LongText, StringUUID from .utils.file_input_compat import ( build_file_from_mapping_without_lookup, @@ -1145,8 +1149,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def _load_full_content(session: orm.Session, file_id: str, storage: Storage): - from .model import UploadFile - stmt = sa.select(UploadFile).where(UploadFile.id == file_id) file = session.scalars(stmt).first() assert file is not None, f"UploadFile with id {file_id} should exist but not" @@ -1240,10 +1242,11 @@ class WorkflowNodeExecutionOffload(Base): ) file: Mapped[Optional["UploadFile"]] = orm.relationship( + UploadFile, foreign_keys=[file_id], lazy="raise", uselist=False, - primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id", + primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id, ) @@ -2017,10 +2020,11 @@ class WorkflowDraftVariableFile(Base): # Relationship to UploadFile upload_file: Mapped["UploadFile"] = orm.relationship( + UploadFile, foreign_keys=[upload_file_id], lazy="raise", uselist=False, - primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id", + primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id, ) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index a657cd553a..c8d0e31c06 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -61,13 +61,31 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i # check segment is exist if index_node_ids: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - with session_factory.create_session() as session: - dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) - if dataset: - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + # Wrap vector / keyword index cleanup in try/except so that a transient + # failure here (e.g. billing API hiccup propagated via FeatureService when + # ModelManager is initialized inside ``Vector(dataset)``) does not abort + # the entire task and leave document_segments / child_chunks / image_files + # / metadata bindings stranded in PG. Mirrors the pattern already used in + # ``clean_dataset_task`` so the document row's hard delete (already + # committed by the caller) does not produce orphan PG rows just because + # the vector backend or one of its transitive dependencies was unhappy. + try: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if dataset: + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector / keyword index in clean_document_task, " + "document_id=%s, dataset_id=%s, index_node_ids_count=%d. 
" + "Continuing with PG / storage cleanup; vector orphans can be reaped later.", + document_id, + dataset_id, + len(index_node_ids), + ) total_image_files = [] with session_factory.create_session() as session, session.begin(): diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index e3be24ac74..017d60efac 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -40,12 +40,29 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() total_index_node_ids.extend([segment.index_node_id for segment in segments]) - with session_factory.create_session() as session: - dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) - if dataset: - index_processor.clean( - dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + # Wrap vector / keyword index cleanup in try/except so that a transient + # failure here (e.g. billing API hiccup propagated via FeatureService when + # ``ModelManager`` is initialized inside ``Vector(dataset)``) does not abort + # the task and leave the already-deleted documents' segments stranded in PG. + # The Document rows are hard-deleted in the previous session block, so any + # exception escaping this task would produce orphans that no later request + # can reference back. Mirrors the pattern in ``clean_dataset_task``. + try: + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if dataset: + index_processor.clean( + dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector / keyword index in clean_notion_document_task, " + "dataset_id=%s, document_ids=%s, index_node_ids_count=%d. " + "Continuing with segment deletion; vector orphans can be reaped later.", + dataset_id, + document_ids, + len(total_index_node_ids), + ) with session_factory.create_session() as session, session.begin(): segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 5d201bd801..48d1774ce3 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -11,6 +11,7 @@ from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy +from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) @@ -171,14 +172,13 @@ def process_tenant_plugin_autoupgrade_check_task( fg="green", ) ) - _ = manager.upgrade_plugin( + # Use the service that downloads and uploads the package to the daemon + # first; calling manager.upgrade_plugin directly skips that step and the + # daemon fails because the package never reaches its local bucket. 
+ _ = PluginService.upgrade_plugin_with_marketplace( tenant_id, original_unique_identifier, new_unique_identifier, - PluginInstallationSource.Marketplace, - { - "plugin_unique_identifier": new_unique_identifier, - }, ) except Exception as e: click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red")) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 2fb62e0fc0..bf36da242b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -655,14 +655,25 @@ class TestCleanNotionDocumentTask: # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. - def test_clean_notion_document_task_database_transaction_rollback( + def test_clean_notion_document_task_continues_when_index_processor_fails( self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies ): """ - Test cleanup task behavior when database operations fail. + Index processor failure (e.g. transient billing API error propagated via + ``FeatureService`` when ``Vector(dataset)`` lazily resolves the embedding + model) must NOT abort the cleanup task. The Document rows have already + been hard-deleted in the first session block before vector cleanup runs, + so any uncaught exception escaping the task would strand + ``DocumentSegment`` rows in PG with no parent ``Document``. - This test verifies that the task properly handles database errors - and maintains data consistency. + Contract: the task swallows the index_processor exception, logs it, and + proceeds to delete the segments — leaving PG consistent. (Vector orphans, + if any, can be reaped later by an offline scanner.) + + Regression guard for the production incident where ``clean_document_task`` + / ``clean_notion_document_task`` failed with + ``ValueError("Unable to retrieve billing information...")`` and left + tens of thousands of orphan segments per affected tenant. """ fake = Faker() @@ -725,17 +736,28 @@ class TestCleanNotionDocumentTask: db_session_with_containers.add(segment) db_session_with_containers.commit() - # Mock index processor to raise an exception + # Simulate the production failure mode: index_processor.clean() raises a + # ValueError mirroring ``BillingService._send_request`` returning non-200. mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - mock_index_processor.clean.side_effect = Exception("Index processor error") + mock_index_processor.clean.side_effect = ValueError( + "Unable to retrieve billing information. Please try again later or contact support." + ) - # Execute cleanup task - current implementation propagates the exception - with pytest.raises(Exception, match="Index processor error"): - clean_notion_document_task([document.id], dataset.id) + # Execute cleanup task — must NOT raise even though clean() raises. + # Before the safety-net wrapper this would have re-raised the ValueError, + # aborting the task and leaving DocumentSegment stranded in PG. + clean_notion_document_task([document.id], dataset.id) - # Note: This test demonstrates the task's error handling capability. - # Even with external service errors, the database operations complete successfully. 
- # In a production environment, proper error handling would determine transaction rollback behavior. + # Vector cleanup was attempted exactly once. + mock_index_processor.clean.assert_called_once() + + # The crucial assertion: despite the index processor failure, the + # final session block (line 51-52, ``DELETE FROM document_segments``) + # still ran and committed. This is what the wrapper buys us — without + # it the production incident left tens of thousands of orphan segments + # per affected tenant. Aligns with the assertion shape used by the + # happy-path test (``test_clean_notion_document_task_success``). + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 def test_clean_notion_document_task_with_large_number_of_documents( self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 1b391e67ec..230c51161f 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -23,6 +23,7 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.service_api.dataset.document import ( DeprecatedDocumentAddByTextApi, + DeprecatedDocumentUpdateByFileApi, DeprecatedDocumentUpdateByTextApi, DocumentAddByFileApi, DocumentAddByTextApi, @@ -32,7 +33,6 @@ from controllers.service_api.dataset.document import ( DocumentListQuery, DocumentTextCreatePayload, DocumentTextUpdate, - DocumentUpdateByFileApi, DocumentUpdateByTextApi, InvalidMetadataError, ) @@ -1095,8 +1095,8 @@ class TestArchivedDocumentImmutableError: assert error.code == 403 -class TestDocumentTextRouteDeprecation: - """Test that legacy underscore text routes stay marked deprecated.""" +class TestDocumentRouteDeprecation: + """Test that legacy document routes stay marked deprecated.""" def test_create_by_text_legacy_alias_is_deprecated(self): """Ensure only the legacy create-by-text alias is marked deprecated.""" @@ -1108,10 +1108,15 @@ class TestDocumentTextRouteDeprecation: assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True + def test_update_by_file_legacy_aliases_are_deprecated(self): + """Ensure only the legacy file-update aliases are marked deprecated.""" + assert DeprecatedDocumentUpdateByFileApi.post.__apidoc__["deprecated"] is True + assert DocumentApi.patch.__apidoc__.get("deprecated") is not True + # ============================================================================= # Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi, -# DocumentUpdateByFileApi. +# and the canonical/deprecated document file update routes. # # These controllers use ``@cloud_edition_billing_resource_check`` (does NOT # preserve ``__wrapped__``) and ``@cloud_edition_billing_rate_limit_check`` @@ -1359,13 +1364,52 @@ class TestDocumentAddByFileApiPost: api.post(tenant_id=mock_tenant.id, dataset_id=mock_dataset.id) -class TestDocumentUpdateByFileApiPost: - """Test suite for DocumentUpdateByFileApi.post() endpoint. +class TestDocumentUpdateByFileApiPatch: + """Test suite for the canonical document file update endpoint. 
- ``post`` is wrapped by ``@cloud_edition_billing_resource_check`` and + ``patch`` is wrapped by ``@cloud_edition_billing_resource_check`` and ``@cloud_edition_billing_rate_limit_check``. """ + @pytest.mark.parametrize("route_name", ["update_by_file", "update-by-file"]) + @patch("controllers.service_api.dataset.document._update_document_by_file") + @patch("controllers.service_api.wraps.FeatureService") + @patch("controllers.service_api.wraps.validate_and_get_api_token") + def test_update_by_file_deprecated_aliases_delegate_to_shared_handler( + self, + mock_validate_token, + mock_feature_svc, + mock_update_document_by_file, + route_name, + app, + mock_tenant, + mock_dataset, + ): + """Test legacy POST aliases still dispatch while marked deprecated.""" + _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) + mock_update_document_by_file.return_value = ({"document": {"id": "doc-1"}, "batch": "batch-1"}, 200) + + doc_id = str(uuid.uuid4()) + with app.test_request_context( + f"/datasets/{mock_dataset.id}/documents/{doc_id}/{route_name}", + method="POST", + headers={"Authorization": "Bearer test_token"}, + ): + api = DeprecatedDocumentUpdateByFileApi() + response, status = api.post( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + document_id=doc_id, + ) + + assert status == 200 + assert response["batch"] == "batch-1" + mock_update_document_by_file.assert_called_once_with( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + document_id=doc_id, + ) + @patch("controllers.service_api.dataset.document.db") @patch("controllers.service_api.wraps.FeatureService") @patch("controllers.service_api.wraps.validate_and_get_api_token") @@ -1387,15 +1431,15 @@ class TestDocumentUpdateByFileApiPost: doc_id = str(uuid.uuid4()) data = {"file": (BytesIO(b"content"), "test.pdf", "application/pdf")} with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file", - method="POST", + f"/datasets/{mock_dataset.id}/documents/{doc_id}", + method="PATCH", content_type="multipart/form-data", data=data, headers={"Authorization": "Bearer test_token"}, ): - api = DocumentUpdateByFileApi() + api = DocumentApi() with pytest.raises(ValueError, match="Dataset does not exist"): - api.post( + api.patch( tenant_id=mock_tenant.id, dataset_id=mock_dataset.id, document_id=doc_id, @@ -1423,15 +1467,15 @@ class TestDocumentUpdateByFileApiPost: doc_id = str(uuid.uuid4()) data = {"file": (BytesIO(b"content"), "test.pdf", "application/pdf")} with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file", - method="POST", + f"/datasets/{mock_dataset.id}/documents/{doc_id}", + method="PATCH", content_type="multipart/form-data", data=data, headers={"Authorization": "Bearer test_token"}, ): - api = DocumentUpdateByFileApi() + api = DocumentApi() with pytest.raises(ValueError, match="External datasets"): - api.post( + api.patch( tenant_id=mock_tenant.id, dataset_id=mock_dataset.id, document_id=doc_id, @@ -1482,14 +1526,14 @@ class TestDocumentUpdateByFileApiPost: doc_id = str(uuid.uuid4()) data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")} with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file", - method="POST", + f"/datasets/{mock_dataset.id}/documents/{doc_id}", + method="PATCH", content_type="multipart/form-data", data=data, headers={"Authorization": "Bearer test_token"}, ): - api = DocumentUpdateByFileApi() - response, status = api.post( + api = DocumentApi() + 
response, status = api.patch( tenant_id=mock_tenant.id, dataset_id=mock_dataset.id, document_id=doc_id, diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index 9de04c80ba..f84ce2771f 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -146,10 +146,7 @@ def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module, def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): dataset = SimpleNamespace(id="dataset-1") - with ( - patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), - patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), - ): + with patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"): default_vector = vector_factory_module.Vector(dataset) custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) @@ -166,10 +163,57 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): "original_chunk_id", ] assert custom_vector._attributes == ["doc_id"] - assert default_vector._embeddings == "embeddings" + # ``_embeddings`` is now a lazy proxy that defers materializing the real + # embedding model until ``embed_*`` is invoked, so cleanup paths never + # trigger billing/feature-service calls during ``Vector(dataset)`` + # construction. See ``_LazyEmbeddings``. + assert isinstance(default_vector._embeddings, vector_factory_module._LazyEmbeddings) assert default_vector._vector_processor == "processor" +def test_lazy_embeddings_defer_real_load_until_first_embed_call(vector_factory_module, monkeypatch): + """``Vector(dataset)`` must not transitively call ``ModelManager`` during + construction. The real embedding model should only be materialized on the + first ``embed_*`` call (i.e. create / search paths) so cleanup paths + (``delete_by_ids`` / ``delete``) remain resilient to billing-API failures. + """ + for_tenant_mock = MagicMock(side_effect=AssertionError("ModelManager.for_tenant must not be called eagerly")) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) + + dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + proxy = vector_factory_module._LazyEmbeddings(dataset) + + # Construction alone does not trigger ModelManager / FeatureService / BillingService. + for_tenant_mock.assert_not_called() + + # Exercising an embed_* method materializes the real model exactly once. + inner_model = MagicMock() + inner_model.embed_documents.return_value = [[0.1, 0.2]] + cached_embedding_mock = MagicMock(return_value=inner_model) + real_for_tenant = MagicMock() + real_for_tenant.get_model_instance.return_value = "embedding-model-instance" + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", MagicMock(return_value=real_for_tenant)) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", cached_embedding_mock) + + result = proxy.embed_documents(["hello"]) + + assert result == [[0.1, 0.2]] + cached_embedding_mock.assert_called_once_with("embedding-model-instance") + inner_model.embed_documents.assert_called_once_with(["hello"]) + + # Subsequent calls reuse the materialized model (no re-resolution). 
+ inner_model.embed_documents.reset_mock() + cached_embedding_mock.reset_mock() + proxy.embed_documents(["world"]) + cached_embedding_mock.assert_not_called() + inner_model.embed_documents.assert_called_once_with(["world"]) + + def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): calls = {"vector_type": None, "init_args": None} diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index 507e1c8c3a..876b889beb 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -45,7 +45,7 @@ class TestWorkflowModelValidation: workflow = Workflow.new( tenant_id=tenant_id, app_id=app_id, - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=graph, features=features, @@ -58,7 +58,7 @@ class TestWorkflowModelValidation: # Assert assert workflow.tenant_id == tenant_id assert workflow.app_id == app_id - assert workflow.type == WorkflowType.WORKFLOW.value + assert workflow.type == WorkflowType.WORKFLOW assert workflow.version == "draft" assert workflow.graph == graph assert workflow.created_by == created_by @@ -68,7 +68,7 @@ class TestWorkflowModelValidation: def test_workflow_type_enum_values(self): """Test WorkflowType enum values.""" # Assert - assert WorkflowType.WORKFLOW.value == "workflow" + assert WorkflowType.WORKFLOW == "workflow" assert WorkflowType.CHAT.value == "chat" assert WorkflowType.RAG_PIPELINE.value == "rag-pipeline" @@ -89,7 +89,7 @@ class TestWorkflowModelValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=json.dumps(graph_data), features="{}", @@ -114,7 +114,7 @@ class TestWorkflowModelValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph="{}", features=json.dumps(features_data), @@ -138,7 +138,7 @@ class TestWorkflowModelValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="v1.0", graph="{}", features="{}", @@ -176,11 +176,11 @@ class TestWorkflowRunStateTransitions: tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id, - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, version="draft", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=created_by, ) @@ -188,9 +188,9 @@ class TestWorkflowRunStateTransitions: assert workflow_run.tenant_id == tenant_id assert workflow_run.app_id == app_id assert workflow_run.workflow_id == workflow_id - assert workflow_run.type == WorkflowType.WORKFLOW.value - assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value - assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value + assert workflow_run.type == WorkflowType.WORKFLOW + assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING + assert workflow_run.status == WorkflowExecutionStatus.RUNNING assert workflow_run.created_by == created_by def test_workflow_run_state_transition_running_to_succeeded(self): @@ -200,21 +200,21 @@ class 
TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.END_USER.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.END_USER, created_by=str(uuid4()), ) # Act - workflow_run.status = WorkflowExecutionStatus.SUCCEEDED.value + workflow_run.status = WorkflowExecutionStatus.SUCCEEDED workflow_run.finished_at = datetime.now(UTC) workflow_run.elapsed_time = 2.5 # Assert - assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value + assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED assert workflow_run.finished_at is not None assert workflow_run.elapsed_time == 2.5 @@ -225,21 +225,21 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) # Act - workflow_run.status = WorkflowExecutionStatus.FAILED.value + workflow_run.status = WorkflowExecutionStatus.FAILED workflow_run.error = "Node execution failed: Invalid input" workflow_run.finished_at = datetime.now(UTC) # Assert - assert workflow_run.status == WorkflowExecutionStatus.FAILED.value + assert workflow_run.status == WorkflowExecutionStatus.FAILED assert workflow_run.error == "Node execution failed: Invalid input" assert workflow_run.finished_at is not None @@ -250,20 +250,20 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, version="draft", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) # Act - workflow_run.status = WorkflowExecutionStatus.STOPPED.value + workflow_run.status = WorkflowExecutionStatus.STOPPED workflow_run.finished_at = datetime.now(UTC) # Assert - assert workflow_run.status == WorkflowExecutionStatus.STOPPED.value + assert workflow_run.status == WorkflowExecutionStatus.STOPPED assert workflow_run.finished_at is not None def test_workflow_run_state_transition_running_to_paused(self): @@ -273,19 +273,19 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.END_USER.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.END_USER, created_by=str(uuid4()), ) # Act - workflow_run.status = WorkflowExecutionStatus.PAUSED.value + workflow_run.status = 
WorkflowExecutionStatus.PAUSED # Assert - assert workflow_run.status == WorkflowExecutionStatus.PAUSED.value + assert workflow_run.status == WorkflowExecutionStatus.PAUSED assert workflow_run.finished_at is None # Not finished when paused def test_workflow_run_state_transition_paused_to_running(self): @@ -295,19 +295,19 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.PAUSED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.PAUSED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) # Act - workflow_run.status = WorkflowExecutionStatus.RUNNING.value + workflow_run.status = WorkflowExecutionStatus.RUNNING # Assert - assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value + assert workflow_run.status == WorkflowExecutionStatus.RUNNING def test_workflow_run_with_partial_succeeded_status(self): """Test workflow run with partial-succeeded status.""" @@ -316,17 +316,17 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), exceptions_count=2, ) # Assert - assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value + assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED assert workflow_run.exceptions_count == 2 def test_workflow_run_with_inputs_and_outputs(self): @@ -340,11 +340,11 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.END_USER.value, + status=WorkflowExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.END_USER, created_by=str(uuid4()), inputs=json.dumps(inputs), outputs=json.dumps(outputs), @@ -362,11 +362,11 @@ class TestWorkflowRunStateTransitions: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, version="draft", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), graph=json.dumps(graph), ) @@ -391,11 +391,11 @@ class TestWorkflowRunStateTransitions: tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id, - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - 
status=WorkflowExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=created_by, total_tokens=1500, total_steps=5, @@ -410,7 +410,7 @@ class TestWorkflowRunStateTransitions: assert result["tenant_id"] == tenant_id assert result["app_id"] == app_id assert result["workflow_id"] == workflow_id - assert result["status"] == WorkflowExecutionStatus.SUCCEEDED.value + assert result["status"] == WorkflowExecutionStatus.SUCCEEDED assert result["total_tokens"] == 1500 assert result["total_steps"] == 5 @@ -422,18 +422,18 @@ class TestWorkflowRunStateTransitions: "tenant_id": str(uuid4()), "app_id": str(uuid4()), "workflow_id": str(uuid4()), - "type": WorkflowType.WORKFLOW.value, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "type": WorkflowType.WORKFLOW, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "version": "v1.0", "graph": {"nodes": [], "edges": []}, "inputs": {"query": "test"}, - "status": WorkflowExecutionStatus.SUCCEEDED.value, + "status": WorkflowExecutionStatus.SUCCEEDED, "outputs": {"result": "success"}, "error": None, "elapsed_time": 3.5, "total_tokens": 2000, "total_steps": 10, - "created_by_role": CreatorUserRole.ACCOUNT.value, + "created_by_role": CreatorUserRole.ACCOUNT, "created_by": str(uuid4()), "created_at": datetime.now(UTC), "finished_at": datetime.now(UTC), @@ -446,7 +446,7 @@ class TestWorkflowRunStateTransitions: # Assert assert workflow_run.id == data["id"] assert workflow_run.workflow_id == data["workflow_id"] - assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value + assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED assert workflow_run.total_tokens == 2000 @@ -467,14 +467,14 @@ class TestNodeExecutionRelationships: tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=workflow_run_id, index=1, node_id="start", node_type=BuiltinNodeTypes.START, title="Start Node", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=created_by, ) @@ -498,15 +498,15 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=2, predecessor_node_id=predecessor_node_id, node_id=current_node_id, node_type=BuiltinNodeTypes.LLM, title="LLM Node", - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) @@ -528,8 +528,8 @@ class TestNodeExecutionRelationships: node_id="llm_test", node_type=BuiltinNodeTypes.LLM, title="Test LLM", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) @@ -549,14 +549,14 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - 
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id="llm_1", node_type=BuiltinNodeTypes.LLM, title="LLM Node", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), inputs=json.dumps(inputs), outputs=json.dumps(outputs), @@ -575,24 +575,24 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id="code_1", node_type=BuiltinNodeTypes.CODE, title="Code Node", - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) # Act - transition to succeeded - node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED node_execution.elapsed_time = 1.2 node_execution.finished_at = datetime.now(UTC) # Assert - assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value + assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED assert node_execution.elapsed_time == 1.2 assert node_execution.finished_at is not None @@ -606,20 +606,20 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=3, node_id="code_1", node_type=BuiltinNodeTypes.CODE, title="Code Node", - status=WorkflowNodeExecutionStatus.FAILED.value, + status=WorkflowNodeExecutionStatus.FAILED, error=error_message, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) # Assert - assert node_execution.status == WorkflowNodeExecutionStatus.FAILED.value + assert node_execution.status == WorkflowNodeExecutionStatus.FAILED assert node_execution.error == error_message def test_node_execution_with_metadata(self): @@ -637,14 +637,14 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id="llm_1", node_type=BuiltinNodeTypes.LLM, title="LLM Node", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), execution_metadata=json.dumps(metadata), ) @@ -660,14 +660,14 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id="start", node_type=BuiltinNodeTypes.START, title="Start", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - 
created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), execution_metadata=None, ) @@ -696,14 +696,14 @@ class TestNodeExecutionRelationships: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id=f"{node_type}_1", node_type=node_type, title=title, - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), ) @@ -734,7 +734,7 @@ class TestGraphConfigurationValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=json.dumps(graph_config), features="{}", @@ -761,7 +761,7 @@ class TestGraphConfigurationValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=json.dumps(graph_config), features="{}", @@ -802,7 +802,7 @@ class TestGraphConfigurationValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=json.dumps(graph_config), features="{}", @@ -835,11 +835,11 @@ class TestGraphConfigurationValidation: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), graph=json.dumps(original_graph), ) @@ -872,7 +872,7 @@ class TestGraphConfigurationValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=json.dumps(graph_config), features="{}", @@ -912,7 +912,7 @@ class TestGraphConfigurationValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=json.dumps(graph_config), features="{}", @@ -933,7 +933,7 @@ class TestGraphConfigurationValidation: workflow = Workflow.new( tenant_id=str(uuid4()), app_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, + type=WorkflowType.WORKFLOW, version="draft", graph=None, features="{}", @@ -956,11 +956,11 @@ class TestGraphConfigurationValidation: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), inputs=None, ) @@ -978,11 +978,11 @@ class TestGraphConfigurationValidation: tenant_id=str(uuid4()), app_id=str(uuid4()), 
workflow_id=str(uuid4()), - type=WorkflowType.WORKFLOW.value, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, version="v1.0", - status=WorkflowExecutionStatus.RUNNING.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowExecutionStatus.RUNNING, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), outputs=None, ) @@ -1000,14 +1000,14 @@ class TestGraphConfigurationValidation: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id="start", node_type=BuiltinNodeTypes.START, title="Start", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), inputs=None, ) @@ -1025,14 +1025,14 @@ class TestGraphConfigurationValidation: tenant_id=str(uuid4()), app_id=str(uuid4()), workflow_id=str(uuid4()), - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, workflow_run_id=str(uuid4()), index=1, node_id="start", node_type=BuiltinNodeTypes.START, title="Start", - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - created_by_role=CreatorUserRole.ACCOUNT.value, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid4()), outputs=None, ) diff --git a/api/tests/unit_tests/oss/__mock/baidu_obs.py b/api/tests/unit_tests/oss/__mock/baidu_obs.py new file mode 100644 index 0000000000..d70a7c2eaa --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/baidu_obs.py @@ -0,0 +1,70 @@ +import base64 +import hashlib +import os +from io import BytesIO +from types import SimpleNamespace + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from baidubce.services.bos.bos_client import BosClient + +from tests.unit_tests.oss.__mock.base import ( + get_example_bucket, + get_example_data, + get_example_filename, + get_example_filepath, +) + + +class MockBaiduObsClass: + def __init__(self, config=None): + self.bucket_name = get_example_bucket() + self.key = get_example_filename() + self.content = get_example_data() + self.filepath = get_example_filepath() + + def put_object(self, bucket_name, key, data, content_length=None, content_md5=None, **kwargs): + assert bucket_name == self.bucket_name + assert key == self.key + assert data == self.content + assert content_length == len(self.content) + expected_md5 = base64.standard_b64encode(hashlib.md5(self.content).digest()) + assert content_md5 == expected_md5 + + def get_object(self, bucket_name, key, **kwargs): + assert bucket_name == self.bucket_name + assert key == self.key + return SimpleNamespace(data=BytesIO(self.content)) + + def get_object_to_file(self, bucket_name, key, file_name, **kwargs): + assert bucket_name == self.bucket_name + assert key == self.key + assert file_name == self.filepath + + def get_object_meta_data(self, bucket_name, key, **kwargs): + assert bucket_name == self.bucket_name + assert key == self.key + return SimpleNamespace(status=200) + + def delete_object(self, bucket_name, key, **kwargs): + assert bucket_name == self.bucket_name + assert key == self.key + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + 
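+# Usage sketch (assumed; mirrors the other storage mocks under
+# tests/unit_tests/oss/__mock). With MOCK_SWITCH=true, a test that requests
+# the fixture below exercises BaiduObsStorage against these in-memory
+# assertions instead of a live BOS endpoint:
+#
+#     def test_save(setup_baidu_obs_mock):
+#         storage = BaiduObsStorage()
+#         storage.save(get_example_filename(), get_example_data())
+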
+@pytest.fixture +def setup_baidu_obs_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(BosClient, "__init__", MockBaiduObsClass.__init__) + monkeypatch.setattr(BosClient, "put_object", MockBaiduObsClass.put_object) + monkeypatch.setattr(BosClient, "get_object", MockBaiduObsClass.get_object) + monkeypatch.setattr(BosClient, "get_object_to_file", MockBaiduObsClass.get_object_to_file) + monkeypatch.setattr(BosClient, "get_object_meta_data", MockBaiduObsClass.get_object_meta_data) + monkeypatch.setattr(BosClient, "delete_object", MockBaiduObsClass.delete_object) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/baidu_obs/__init__.py b/api/tests/unit_tests/oss/baidu_obs/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/unit_tests/oss/baidu_obs/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py b/api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py new file mode 100644 index 0000000000..053f811264 --- /dev/null +++ b/api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py @@ -0,0 +1,60 @@ +from unittest.mock import MagicMock, patch + +import pytest +from baidubce.auth.bce_credentials import BceCredentials +from baidubce.bce_client_configuration import BceClientConfiguration + +from extensions.storage.baidu_obs_storage import BaiduObsStorage +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, + get_example_bucket, +) + +pytest_plugins = ("tests.unit_tests.oss.__mock.baidu_obs",) + + +class TestBaiduObs(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_baidu_obs_mock): + """Executed before each test method.""" + with ( + patch.object(BceCredentials, "__init__", return_value=None), + patch.object(BceClientConfiguration, "__init__", return_value=None), + ): + self.storage = BaiduObsStorage() + self.storage.bucket_name = get_example_bucket() + + +class TestBaiduObsConfiguration: + def test_init_with_config(self): + mock_dify_config = MagicMock() + mock_dify_config.BAIDU_OBS_BUCKET_NAME = "test-bucket" + mock_dify_config.BAIDU_OBS_ACCESS_KEY = "test-access-key" + mock_dify_config.BAIDU_OBS_SECRET_KEY = "test-secret-key" + mock_dify_config.BAIDU_OBS_ENDPOINT = "https://bj.bcebos.com" + + mock_credentials = MagicMock(name="credentials") + mock_config = MagicMock(name="config") + mock_client = MagicMock(name="client") + + with ( + patch("extensions.storage.baidu_obs_storage.dify_config", mock_dify_config), + patch("extensions.storage.baidu_obs_storage.BceCredentials", return_value=mock_credentials) as credentials, + patch( + "extensions.storage.baidu_obs_storage.BceClientConfiguration", return_value=mock_config + ) as configuration, + patch("extensions.storage.baidu_obs_storage.BosClient", return_value=mock_client) as client_cls, + ): + storage = BaiduObsStorage() + + assert storage.bucket_name == "test-bucket" + assert storage.client == mock_client + credentials.assert_called_once_with( + access_key_id="test-access-key", + secret_access_key="test-secret-key", + ) + configuration.assert_called_once_with( + credentials=mock_credentials, + endpoint="https://bj.bcebos.com", + ) + client_cls.assert_called_once_with(config=mock_config) diff --git a/api/tests/unit_tests/tasks/test_clean_document_task.py b/api/tests/unit_tests/tasks/test_clean_document_task.py new file mode 100644 index 0000000000..26d7b3e3b6 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_clean_document_task.py @@ -0,0 +1,291 @@ +""" +Unit tests for 
clean_document_task. + +Focuses on the resilience contract added by the billing-failure fix: +``index_processor.clean()`` is wrapped in ``try/except`` so that a transient +failure inside the vector / keyword cleanup (e.g. ``ValueError("Unable to +retrieve billing information...")`` raised by ``BillingService._send_request`` +when ``Vector(dataset)`` transitively triggers ``FeatureService.get_features``) +does not abort the entire task and leave PG with stranded ``DocumentSegment`` +/ ``ChildChunk`` / ``UploadFile`` / ``DatasetMetadataBinding`` rows. +""" + +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from tasks.clean_document_task import clean_document_task + + +@pytest.fixture +def document_id(): + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + return str(uuid.uuid4()) + + +@pytest.fixture +def tenant_id(): + return str(uuid.uuid4()) + + +@pytest.fixture +def mock_session_factory(): + """Patch ``session_factory.create_session`` to return per-call mock sessions. + + Each call to ``create_session()`` yields a fresh ``MagicMock`` session so we + can assert ``execute()`` calls across the multiple short-lived transactions + used by ``clean_document_task``. + """ + with patch("tasks.clean_document_task.session_factory", autospec=True) as mock_sf: + sessions: list[MagicMock] = [] + + def _create_session(): + session = MagicMock() + session.scalars.return_value.all.return_value = [] + session.execute.return_value.all.return_value = [] + session.scalar.return_value = None + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + sessions.append(session) + return cm + + mock_sf.create_session.side_effect = _create_session + yield mock_sf, sessions + + +@pytest.fixture +def mock_storage(): + with patch("tasks.clean_document_task.storage", autospec=True) as mock: + mock.delete.return_value = None + yield mock + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock ``IndexProcessorFactory`` so we can inject behavior into ``clean``.""" + with patch("tasks.clean_document_task.IndexProcessorFactory", autospec=True) as factory_cls: + processor = MagicMock() + processor.clean.return_value = None + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = processor + factory_cls.return_value = factory_instance + + yield { + "factory_cls": factory_cls, + "factory_instance": factory_instance, + "processor": processor, + } + + +def _build_segment(segment_id: str, content: str = "segment content") -> MagicMock: + seg = MagicMock() + seg.id = segment_id + seg.index_node_id = f"node-{segment_id}" + seg.content = content + return seg + + +def _build_dataset(dataset_id: str, tenant_id: str) -> MagicMock: + ds = MagicMock() + ds.id = dataset_id + ds.tenant_id = tenant_id + return ds + + +class TestVectorCleanupResilience: + """Vector / keyword cleanup must not abort the task on transient failure.""" + + def test_billing_failure_during_vector_cleanup_does_not_skip_pg_cleanup( + self, + document_id, + dataset_id, + tenant_id, + mock_session_factory, + mock_storage, + mock_index_processor_factory, + ): + """Reproduces the production incident: + + ``Vector(dataset)`` transitively calls ``FeatureService.get_features`` + which calls ``BillingService._send_request("GET", ...)``. When billing + returns non-200 it raises ``ValueError("Unable to retrieve billing + information...")``. 
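+
+        Failure chain, per the modules referenced above (a sketch of the call
+        path, not verbatim code):
+
+            Vector(dataset)
+              -> FeatureService.get_features(tenant_id)
+                -> BillingService._send_request("GET", ...)  # billing non-200
+                  -> ValueError("Unable to retrieve billing information...")
+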
Before the fix this propagated out of + ``clean_document_task`` and left ``DocumentSegment`` / ``ChildChunk`` / + ``UploadFile`` / ``DatasetMetadataBinding`` rows orphaned because the + already-deleted ``Document`` row had been hard-committed by the caller + (``dataset_service.delete_document``) before ``.delay()`` was invoked. + + Contract: a billing failure inside ``index_processor.clean()`` must be + caught, logged, and the rest of the task must continue so PG ends up + consistent with the deleted ``Document`` even if Qdrant retains + orphan vectors that can be reaped later. + """ + mock_sf, sessions = mock_session_factory + + # First create_session(): Step 1 (load segments + attachments). + step1_session = MagicMock() + step1_session.scalars.return_value.all.return_value = [ + _build_segment("seg-1"), + _build_segment("seg-2"), + ] + step1_session.execute.return_value.all.return_value = [] + step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id) + # Second create_session(): Step 2 (vector cleanup). Returns dataset. + step2_session = MagicMock() + step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id) + step2_session.scalars.return_value.all.return_value = [] + step2_session.execute.return_value.all.return_value = [] + # Subsequent sessions: Step 3+ (image / segment / file / metadata cleanup). + # Default fixture returns empty results which is fine for these short txns. + cm1, cm2 = MagicMock(), MagicMock() + cm1.__enter__.return_value = step1_session + cm1.__exit__.return_value = None + cm2.__enter__.return_value = step2_session + cm2.__exit__.return_value = None + + def _default_cm(): + session = MagicMock() + session.scalars.return_value.all.return_value = [] + session.execute.return_value.all.return_value = [] + session.scalar.return_value = None + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + sessions.append(session) + return cm + + mock_sf.create_session.side_effect = [cm1, cm2] + [_default_cm() for _ in range(10)] + + # Simulate the production failure: index_processor.clean() raises ValueError + # mirroring BillingService._send_request when billing returns non-200. + mock_index_processor_factory["processor"].clean.side_effect = ValueError( + "Unable to retrieve billing information. Please try again later or contact support." + ) + + # Act — must not raise out of the task even though clean() raises. + clean_document_task( + document_id=document_id, + dataset_id=dataset_id, + doc_form="paragraph", + file_id=None, + ) + + # Assert + # 1. Vector cleanup was attempted. + mock_index_processor_factory["processor"].clean.assert_called_once() + # 2. Despite the failure the task continued: at least one DocumentSegment + # delete was issued. We use the count of session.execute calls across + # later short transactions as a proxy for "Step 3+ executed". + execute_calls = sum(s.execute.call_count for s in sessions) + assert execute_calls > 0, ( + "Step 3+ DB cleanup did not run after vector cleanup failure; " + "this regression would re-introduce the orphan-segment bug." + ) + + def test_vector_cleanup_success_path_remains_unaffected( + self, + document_id, + dataset_id, + tenant_id, + mock_session_factory, + mock_storage, + mock_index_processor_factory, + ): + """Backward-compat: the happy path must still call ``clean()`` exactly + once with the expected arguments and complete without errors. 
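+
+        Expected call shape (keyword names taken from the assertions below;
+        the positional arguments are assumed):
+
+            index_processor.clean(
+                dataset,
+                index_node_ids,
+                with_keywords=True,
+                delete_child_chunks=True,
+                delete_summaries=True,
+            )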
+ """ + mock_sf, sessions = mock_session_factory + + step1_session = MagicMock() + step1_session.scalars.return_value.all.return_value = [_build_segment("seg-1")] + step1_session.execute.return_value.all.return_value = [] + step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id) + step2_session = MagicMock() + step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id) + step2_session.scalars.return_value.all.return_value = [] + step2_session.execute.return_value.all.return_value = [] + cm1, cm2 = MagicMock(), MagicMock() + cm1.__enter__.return_value = step1_session + cm1.__exit__.return_value = None + cm2.__enter__.return_value = step2_session + cm2.__exit__.return_value = None + + def _default_cm(): + session = MagicMock() + session.scalars.return_value.all.return_value = [] + session.execute.return_value.all.return_value = [] + session.scalar.return_value = None + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + sessions.append(session) + return cm + + mock_sf.create_session.side_effect = [cm1, cm2] + [_default_cm() for _ in range(10)] + + clean_document_task( + document_id=document_id, + dataset_id=dataset_id, + doc_form="paragraph", + file_id=None, + ) + + assert mock_index_processor_factory["processor"].clean.call_count == 1 + # Index cleanup invoked with the expected delete_summaries / delete_child_chunks flags. + _, kwargs = mock_index_processor_factory["processor"].clean.call_args + assert kwargs.get("with_keywords") is True + assert kwargs.get("delete_child_chunks") is True + assert kwargs.get("delete_summaries") is True + + def test_no_segments_skips_vector_cleanup( + self, + document_id, + dataset_id, + tenant_id, + mock_session_factory, + mock_storage, + mock_index_processor_factory, + ): + """When the document has no segments (e.g. indexing failed before + producing any), vector cleanup must not be attempted — and therefore + the new try/except wrapper does not change behavior here. + """ + mock_sf, sessions = mock_session_factory + + step1_session = MagicMock() + step1_session.scalars.return_value.all.return_value = [] # no segments + step1_session.execute.return_value.all.return_value = [] + step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id) + cm1 = MagicMock() + cm1.__enter__.return_value = step1_session + cm1.__exit__.return_value = None + + def _default_cm(): + session = MagicMock() + session.scalars.return_value.all.return_value = [] + session.execute.return_value.all.return_value = [] + session.scalar.return_value = None + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + sessions.append(session) + return cm + + mock_sf.create_session.side_effect = [cm1] + [_default_cm() for _ in range(10)] + + clean_document_task( + document_id=document_id, + dataset_id=dataset_id, + doc_form="paragraph", + file_id=None, + ) + + # Vector cleanup is gated on ``index_node_ids``; when there are no + # segments the IndexProcessorFactory path is never entered. 
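+        # In other words, the guard under test is effectively (assumed shape,
+        # not the task's verbatim code):
+        #
+        #     if index_node_ids:
+        #         processor = IndexProcessorFactory(doc_form).init_index_processor()
+        #         try:
+        #             processor.clean(...)
+        #         except Exception:
+        #             logging.exception("Index cleanup failed; continuing")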
+ mock_index_processor_factory["factory_cls"].assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_process_tenant_plugin_autoupgrade_check_task.py b/api/tests/unit_tests/tasks/test_process_tenant_plugin_autoupgrade_check_task.py new file mode 100644 index 0000000000..75d8b92044 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_process_tenant_plugin_autoupgrade_check_task.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from core.plugin.entities.marketplace import MarketplacePluginSnapshot +from core.plugin.entities.plugin import PluginInstallationSource +from models.account import TenantPluginAutoUpgradeStrategy + +MODULE = "tasks.process_tenant_plugin_autoupgrade_check_task" + + +def _make_plugin(plugin_id: str, version: str, source=PluginInstallationSource.Marketplace): + """Build a minimal stand-in for a PluginInstallation entry returned by manager.list_plugins.""" + return SimpleNamespace( + plugin_id=plugin_id, + version=version, + plugin_unique_identifier=f"{plugin_id}:{version}@deadbeef", + source=source, + ) + + +def _make_manifest(plugin_id: str, latest_version: str) -> MarketplacePluginSnapshot: + org, name = plugin_id.split("/", 1) + return MarketplacePluginSnapshot( + org=org, + name=name, + latest_version=latest_version, + latest_package_identifier=f"{plugin_id}:{latest_version}@cafe1234", + latest_package_url=f"https://marketplace.example/{plugin_id}/{latest_version}.difypkg", + ) + + +def _run_task( + *, + plugins: list, + manifests: list[MarketplacePluginSnapshot], + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + exclude_plugins=None, + include_plugins=None, +): + """ + Execute the celery task synchronously with mocks for the plugin manager, + the marketplace cache and PluginService.upgrade_plugin_with_marketplace. + Returns the upgrade-call recorder so each test can assert on it. + """ + fake_manager = MagicMock() + fake_manager.list_plugins.return_value = plugins + + upgrade_calls: list[tuple[str, str, str]] = [] + + def _record_upgrade(tenant_id, original, new): + upgrade_calls.append((tenant_id, original, new)) + + with ( + patch(f"{MODULE}.PluginInstaller", return_value=fake_manager), + patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests), + patch( + f"{MODULE}.PluginService.upgrade_plugin_with_marketplace", + side_effect=_record_upgrade, + ) as upgrade_mock, + ): + from tasks.process_tenant_plugin_autoupgrade_check_task import ( + process_tenant_plugin_autoupgrade_check_task, + ) + + process_tenant_plugin_autoupgrade_check_task( + "tenant-1", + strategy_setting, + 0, + upgrade_mode, + exclude_plugins or [], + include_plugins or [], + ) + + return upgrade_mock, upgrade_calls + + +class TestUpgradeCallsMarketplaceService: + """ + Regression test for the bug where the auto-upgrade task called + manager.upgrade_plugin directly, which skipped downloading the new package + from marketplace and uploading it to the daemon. The daemon then failed with + "package file not found" and the upgrade silently never completed. 
+ """ + + def test_upgrade_routes_through_plugin_service(self): + plugin = _make_plugin("acme/foo", "1.0.0") + manifest = _make_manifest("acme/foo", "1.0.1") + + upgrade_mock, calls = _run_task(plugins=[plugin], manifests=[manifest]) + + upgrade_mock.assert_called_once() + assert calls == [("tenant-1", plugin.plugin_unique_identifier, manifest.latest_package_identifier)] + + def test_does_not_call_manager_upgrade_plugin_directly(self): + """Locks in that we never go back to the broken path that bypassed download/upload.""" + plugin = _make_plugin("acme/foo", "1.0.0") + manifest = _make_manifest("acme/foo", "1.0.1") + + fake_manager = MagicMock() + fake_manager.list_plugins.return_value = [plugin] + + with ( + patch(f"{MODULE}.PluginInstaller", return_value=fake_manager), + patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=[manifest]), + patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace"), + ): + from tasks.process_tenant_plugin_autoupgrade_check_task import ( + process_tenant_plugin_autoupgrade_check_task, + ) + + process_tenant_plugin_autoupgrade_check_task( + "tenant-1", + TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + 0, + TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + [], + [], + ) + + fake_manager.upgrade_plugin.assert_not_called() + + +class TestStrategySetting: + def test_disabled_strategy_skips_everything(self): + upgrade_mock, _ = _run_task( + plugins=[_make_plugin("acme/foo", "1.0.0")], + manifests=[_make_manifest("acme/foo", "1.0.1")], + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED, + ) + upgrade_mock.assert_not_called() + + def test_fix_only_upgrades_patch_version(self): + upgrade_mock, calls = _run_task( + plugins=[_make_plugin("acme/foo", "1.0.0")], + manifests=[_make_manifest("acme/foo", "1.0.5")], + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + ) + upgrade_mock.assert_called_once() + assert calls[0][2].endswith(":1.0.5@cafe1234") + + def test_fix_only_skips_minor_bump(self): + upgrade_mock, _ = _run_task( + plugins=[_make_plugin("acme/foo", "1.0.0")], + manifests=[_make_manifest("acme/foo", "1.1.0")], + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + ) + upgrade_mock.assert_not_called() + + def test_fix_only_skips_major_bump(self): + upgrade_mock, _ = _run_task( + plugins=[_make_plugin("acme/foo", "1.0.0")], + manifests=[_make_manifest("acme/foo", "2.0.0")], + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + ) + upgrade_mock.assert_not_called() + + def test_latest_strategy_skips_when_versions_equal(self): + upgrade_mock, _ = _run_task( + plugins=[_make_plugin("acme/foo", "1.0.0")], + manifests=[_make_manifest("acme/foo", "1.0.0")], + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + ) + upgrade_mock.assert_not_called() + + +class TestUpgradeMode: + def test_mode_all_upgrades_every_marketplace_plugin(self): + plugins = [ + _make_plugin("acme/foo", "1.0.0"), + _make_plugin("acme/bar", "2.0.0"), + ] + manifests = [ + _make_manifest("acme/foo", "1.0.1"), + _make_manifest("acme/bar", "2.0.1"), + ] + + upgrade_mock, calls = _run_task( + plugins=plugins, + manifests=manifests, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + ) + + assert upgrade_mock.call_count == 2 + upgraded_ids = sorted(c[1] for c in calls) + assert upgraded_ids == sorted(p.plugin_unique_identifier for p in plugins) + + def test_mode_all_skips_non_marketplace_sources(self): + plugins = [ + 
_make_plugin("acme/foo", "1.0.0"), + _make_plugin("acme/bar", "2.0.0", source=PluginInstallationSource.Github), + ] + manifests = [ + _make_manifest("acme/foo", "1.0.1"), + _make_manifest("acme/bar", "2.0.1"), + ] + + upgrade_mock, calls = _run_task( + plugins=plugins, + manifests=manifests, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + ) + + assert upgrade_mock.call_count == 1 + assert calls[0][1] == plugins[0].plugin_unique_identifier + + def test_mode_partial_only_upgrades_included_plugins(self): + plugins = [ + _make_plugin("acme/foo", "1.0.0"), + _make_plugin("acme/bar", "2.0.0"), + ] + manifests = [ + _make_manifest("acme/foo", "1.0.1"), + _make_manifest("acme/bar", "2.0.1"), + ] + + upgrade_mock, calls = _run_task( + plugins=plugins, + manifests=manifests, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL, + include_plugins=["acme/foo"], + ) + + assert upgrade_mock.call_count == 1 + assert calls[0][1] == plugins[0].plugin_unique_identifier + + def test_mode_exclude_skips_excluded_plugins(self): + plugins = [ + _make_plugin("acme/foo", "1.0.0"), + _make_plugin("acme/bar", "2.0.0"), + ] + manifests = [ + _make_manifest("acme/foo", "1.0.1"), + _make_manifest("acme/bar", "2.0.1"), + ] + + upgrade_mock, calls = _run_task( + plugins=plugins, + manifests=manifests, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=["acme/bar"], + ) + + assert upgrade_mock.call_count == 1 + assert calls[0][1] == plugins[0].plugin_unique_identifier + + +class TestErrorIsolation: + def test_one_plugin_failure_does_not_block_others(self): + plugins = [ + _make_plugin("acme/foo", "1.0.0"), + _make_plugin("acme/bar", "2.0.0"), + ] + manifests = [ + _make_manifest("acme/foo", "1.0.1"), + _make_manifest("acme/bar", "2.0.1"), + ] + fake_manager = MagicMock() + fake_manager.list_plugins.return_value = plugins + + seen: list[str] = [] + + def _upgrade(tenant_id, original, new): + seen.append(original) + if "foo" in original: + raise RuntimeError("boom") + + with ( + patch(f"{MODULE}.PluginInstaller", return_value=fake_manager), + patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests), + patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace", side_effect=_upgrade), + ): + from tasks.process_tenant_plugin_autoupgrade_check_task import ( + process_tenant_plugin_autoupgrade_check_task, + ) + + process_tenant_plugin_autoupgrade_check_task( + "tenant-1", + TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST, + 0, + TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL, + [], + [], + ) + + assert any("foo" in s for s in seen) + assert any("bar" in s for s in seen) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 911da70a73..af3d54dfb3 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -59,19 +59,25 @@ services: - ${MYSQL_HOST_VOLUME:-./volumes/mysql/data}:/var/lib/mysql ports: - "${EXPOSE_MYSQL_PORT:-3306}:3306" + # mysqladmin ping passes during mysql:8.0's TCP-listening stage even while + # the server is still finalising init, leading to "Lost connection during + # query" on the first real query. Verify with a real SELECT instead. healthcheck: test: [ "CMD", - "mysqladmin", - "ping", - "-u", - "root", + "mysql", + "-h", + "127.0.0.1", + "-uroot", "-p${DB_PASSWORD:-difyai123456}", + "-e", + "SELECT 1", ] interval: 1s timeout: 3s retries: 30 + start_period: 20s # The redis cache. 
redis: diff --git a/e2e/features/apps/share-app.feature b/e2e/features/apps/share-app.feature index 22f89f7ebb..1c707306ef 100644 --- a/e2e/features/apps/share-app.feature +++ b/e2e/features/apps/share-app.feature @@ -17,3 +17,10 @@ Feature: Share app publicly Given a workflow app has been published and shared via API When I open the shared app URL Then the shared app page should be accessible + + @unauthenticated + Scenario: Run a shared workflow app without authentication + Given a workflow app has been published and shared via API + When I open the shared app URL + And I run the shared workflow app + Then the shared workflow run should succeed diff --git a/e2e/features/step-definitions/apps/share-app.steps.ts b/e2e/features/step-definitions/apps/share-app.steps.ts index 24da05baab..d5742bdaa8 100644 --- a/e2e/features/step-definitions/apps/share-app.steps.ts +++ b/e2e/features/step-definitions/apps/share-app.steps.ts @@ -37,3 +37,15 @@ Then('the shared app page should be accessible', async function (this: DifyWorld await expect(this.getPage()).toHaveURL(/\/(workflow|chat)\/[a-zA-Z0-9]+/, { timeout: 15_000 }) await expect(this.getPage().locator('body')).toBeVisible({ timeout: 10_000 }) }) + +When('I run the shared workflow app', async function (this: DifyWorld) { + const page = this.getPage() + const runButton = page.getByTestId('run-button') + + await expect(runButton).toBeEnabled({ timeout: 15_000 }) + await runButton.click() +}) + +Then('the shared workflow run should succeed', async function (this: DifyWorld) { + await expect(this.getPage().getByTestId('status-icon-success')).toBeVisible({ timeout: 55_000 }) +}) diff --git a/e2e/features/step-definitions/apps/workflow-run.steps.ts b/e2e/features/step-definitions/apps/workflow-run.steps.ts index 584a33e774..84c03bfa8f 100644 --- a/e2e/features/step-definitions/apps/workflow-run.steps.ts +++ b/e2e/features/step-definitions/apps/workflow-run.steps.ts @@ -12,8 +12,10 @@ Given('a minimal runnable workflow draft has been synced', async function (this: When('I run the workflow', async function (this: DifyWorld) { const page = this.getPage() - await page.getByText('Test Run').click() - await expect(page.getByText('Running').first()).toBeVisible({ timeout: 15_000 }) + const testRunButton = page.getByText('Test Run') + + await expect(testRunButton).toBeVisible({ timeout: 15_000 }) + await testRunButton.click() }) Then('the workflow run should succeed', async function (this: DifyWorld) { diff --git a/eslint-suppressions.json b/eslint-suppressions.json index a50a9415ea..beb49b7d5f 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -3506,11 +3506,6 @@ "count": 1 } }, - "web/app/components/plugins/reference-setting-modal/auto-update-setting/strategy-picker.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/reference-setting-modal/auto-update-setting/types.ts": { "erasable-syntax-only/enums": { "count": 2 diff --git a/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx b/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx index 22631d35d5..287b431715 100644 --- a/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx +++ b/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx @@ -154,7 +154,6 @@ const mockSnippet: SnippetDetail = { id: 'snippet-1', name: 'Social Media Repurposer', description: 'Turn one blog post into multiple social media variations.', - author: 'Dify', updatedAt: '2026-03-25 10:00', usage: '12', 
icon: '🤖', diff --git a/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx index 50754ffd23..c99f09b2e8 100644 --- a/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx @@ -11,7 +11,6 @@ const mockSnippet: SnippetDetail = { id: 'snippet-1', name: 'Social Media Repurposer', description: 'Turn one blog post into multiple social media variations.', - author: 'Dify', updatedAt: '2026-03-25 10:00', usage: '12', icon: '🤖', diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 932db440da..6fec23cd81 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -160,8 +160,9 @@ const defaultSnippetData = { icon_url: '', }, created_at: 1704067200, - updated_at: '2024-01-02 10:00', - author: '', + created_by: 'user-1', + updated_at: 1704153600, + updated_by: 'user-2', }, ], total: 1, @@ -321,8 +322,9 @@ describe('List', () => { icon_url: '', }, created_at: 1704067200, - updated_at: '2024-01-02 10:00', - author: '', + created_by: 'user-1', + updated_at: 1704153600, + updated_by: 'user-2', }, ] defaultSnippetData.pages[0]!.total = 1 @@ -678,8 +680,8 @@ describe('List', () => { }) it('should reuse the shared empty state when no snippets are available', () => { - defaultSnippetData.pages[0].data = [] - defaultSnippetData.pages[0].total = 0 + defaultSnippetData.pages[0]!.data = [] + defaultSnippetData.pages[0]!.total = 0 renderList({ pageType: 'snippets' }) diff --git a/web/app/components/base/chat/embedded-chatbot/__tests__/hooks.spec.tsx b/web/app/components/base/chat/embedded-chatbot/__tests__/hooks.spec.tsx index acb27fb137..f7337ce8fe 100644 --- a/web/app/components/base/chat/embedded-chatbot/__tests__/hooks.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/__tests__/hooks.spec.tsx @@ -532,6 +532,7 @@ describe('useEmbeddedChatbot', () => { }) it('handleChangeConversation updates current conversation and refetches chat list', async () => { + mockStoreState.embeddedConversationId = null const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp)) act(() => { @@ -548,6 +549,39 @@ describe('useEmbeddedChatbot', () => { expect(result.current.clearChatList).toBe(false) }) + // Scenario: URL-provided conversation_id should take precedence over localStorage value. + it('should prioritize URL conversation_id over localStorage', async () => { + localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ + 'app-1': { 'embedded-user-1': 'stored-conv-id' }, + })) + mockStoreState.embeddedConversationId = 'url-conv-id' + mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({ + user_id: 'embedded-user-1', + conversation_id: 'url-conv-id', + }) + + const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp)) + + await waitFor(() => { + expect(result.current.currentConversationId).toBe('url-conv-id') + }) + }) + + // Scenario: When no URL conversation_id is provided, fall back to localStorage. 
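+  // Resolution order exercised by these scenarios, as implemented in
+  // hooks.tsx in this change:
+  //   conversationId || conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || ''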
+ it('should fall back to localStorage when no URL conversation_id is provided', async () => { + localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ + 'app-1': { DEFAULT: 'stored-conv-id' }, + })) + mockStoreState.embeddedConversationId = null + mockStoreState.embeddedUserId = null + + const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp)) + + await waitFor(() => { + expect(result.current.currentConversationId).toBe('stored-conv-id') + }) + }) + it('handleFeedback invokes updateFeedback service successfully', async () => { const { updateFeedback } = await import('@/service/share') const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp)) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 4ce6bc6318..39cbe66fbb 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -113,7 +113,7 @@ export const useEmbeddedChatbot = (appSourceType: AppSourceType, tryAppId?: stri }) }, [setConversationIdInfo]) const allowResetChat = !conversationId - const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || conversationId || '', [appId, conversationIdInfo, userId, conversationId]) + const currentConversationId = useMemo(() => conversationId || conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId, conversationId]) const handleConversationIdInfoChange = useCallback((changeConversationId: string) => { if (appId) { let prevValue = conversationIdInfo?.[appId || ''] diff --git a/web/app/components/base/date-and-time-picker/common/__tests__/option-list-item.spec.tsx b/web/app/components/base/date-and-time-picker/common/__tests__/option-list-item.spec.tsx index 8ccf8fab73..feca6d2e92 100644 --- a/web/app/components/base/date-and-time-picker/common/__tests__/option-list-item.spec.tsx +++ b/web/app/components/base/date-and-time-picker/common/__tests__/option-list-item.spec.tsx @@ -43,7 +43,7 @@ describe('OptionListItem', () => { , ) - const item = screen.getByRole('listitem') + const item = screen.getByRole('button') expect(item).toHaveClass('bg-components-button-ghost-bg-hover') }) @@ -54,7 +54,7 @@ describe('OptionListItem', () => { , ) - const item = screen.getByRole('listitem') + const item = screen.getByRole('button') expect(item).not.toHaveClass('bg-components-button-ghost-bg-hover') }) }) @@ -100,7 +100,7 @@ describe('OptionListItem', () => { Clickable , ) - fireEvent.click(screen.getByRole('listitem')) + fireEvent.click(screen.getByRole('button')) expect(handleClick).toHaveBeenCalledTimes(1) }) @@ -111,7 +111,7 @@ describe('OptionListItem', () => { Item , ) - fireEvent.click(screen.getByRole('listitem')) + fireEvent.click(screen.getByRole('button')) expect(Element.prototype.scrollIntoView).toHaveBeenCalledWith({ behavior: 'smooth' }) }) @@ -126,7 +126,7 @@ describe('OptionListItem', () => { , ) - const item = screen.getByRole('listitem') + const item = screen.getByRole('button') fireEvent.click(item) fireEvent.click(item) fireEvent.click(item) diff --git a/web/app/components/base/date-and-time-picker/common/__tests__/option-list.spec.tsx b/web/app/components/base/date-and-time-picker/common/__tests__/option-list.spec.tsx new file mode 100644 index 0000000000..9e6a8cace9 --- /dev/null +++ b/web/app/components/base/date-and-time-picker/common/__tests__/option-list.spec.tsx @@ -0,0 +1,28 @@ 
+import { render, screen } from '@testing-library/react'
+import OptionList from '../option-list'
+
+describe('OptionList', () => {
+  it('should render a scrollable list with hidden scrollbar styles', () => {
+    render(
+      <OptionList>
+        <li>Item</li>
+      </OptionList>,
+    )
+
+    const list = screen.getByRole('list')
+
+    expect(list).toHaveClass('overflow-y-auto')
+    expect(list).toHaveClass('[scrollbar-width:none]')
+    expect(list).toHaveClass('[&::-webkit-scrollbar]:hidden')
+  })
+
+  it('should append caller className after default classes', () => {
+    render(
+      <OptionList className="custom-list">
+        <li>Item</li>
+      </OptionList>,
+    )
+
+    expect(screen.getByRole('list')).toHaveClass('custom-list')
+  })
+})
diff --git a/web/app/components/base/date-and-time-picker/common/option-list-item.tsx b/web/app/components/base/date-and-time-picker/common/option-list-item.tsx
index 31b303df1f..e1bfdde4be 100644
--- a/web/app/components/base/date-and-time-picker/common/option-list-item.tsx
+++ b/web/app/components/base/date-and-time-picker/common/option-list-item.tsx
@@ -1,4 +1,4 @@
-import type { FC } from 'react'
+import type { FC, ReactNode } from 'react'
 import { cn } from '@langgenius/dify-ui/cn'
 import * as React from 'react'
 import { useEffect, useRef } from 'react'
@@ -7,7 +7,8 @@
 type OptionListItemProps = {
   isSelected: boolean
   onClick: () => void
   noAutoScroll?: boolean
-} & React.LiHTMLAttributes<HTMLLIElement>
+  children: ReactNode
+}
 
 const OptionListItem: FC<OptionListItemProps> = ({
   isSelected,
@@ -25,16 +26,21 @@ const OptionListItem: FC<OptionListItemProps> = ({
   return (
     <li
       ref={listItemRef}
       …
-      onClick={() => {
-        listItemRef.current?.scrollIntoView({ behavior: 'smooth' })
-        onClick()
-      }}
     >
-      {children}
+      <button
+        …
+        onClick={() => {
+          listItemRef.current?.scrollIntoView({ behavior: 'smooth' })
+          onClick()
+        }}
+      >
+        {children}
+      </button>
     </li>
   )
 }
diff --git a/web/app/components/base/date-and-time-picker/common/option-list.tsx b/web/app/components/base/date-and-time-picker/common/option-list.tsx
new file mode 100644
index 0000000000..8fc2407089
--- /dev/null
+++ b/web/app/components/base/date-and-time-picker/common/option-list.tsx
@@ -0,0 +1,26 @@
+import type { HTMLAttributes, ReactNode } from 'react'
+import { cn } from '@langgenius/dify-ui/cn'
+import * as React from 'react'
+
+type OptionListProps = {
+  children: ReactNode
+} & HTMLAttributes<HTMLUListElement>
+
+const optionListClassName = cn(
+  'flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]',
+  '[scrollbar-width:none] [&::-webkit-scrollbar]:hidden',
+)
+
+const OptionList = ({
+  children,
+  className,
+  ...props
+}: OptionListProps) => {
+  return (
+    <ul className={cn(optionListClassName, className)} {...props}>
+      {children}
+    </ul>
+  )
+}
+
+export default React.memo(OptionList)
diff --git a/web/app/components/base/date-and-time-picker/time-picker/__tests__/options.spec.tsx b/web/app/components/base/date-and-time-picker/time-picker/__tests__/options.spec.tsx
index d7fa3be797..1bf3a52a8b 100644
--- a/web/app/components/base/date-and-time-picker/time-picker/__tests__/options.spec.tsx
+++ b/web/app/components/base/date-and-time-picker/time-picker/__tests__/options.spec.tsx
@@ -64,13 +64,13 @@ describe('TimePickerOptions', () => {
   it('should render selected hour in the list', () => {
     const props = createOptionsProps({ selectedTime: dayjs('2024-01-01 05:30:00') })
     render(<Options {...props} />)
-    const selectedHour = screen.getAllByRole('listitem').find(item => item.textContent === '05')
+    const selectedHour = screen.getAllByRole('button').find(item => item.textContent === '05')
     expect(selectedHour)!.toHaveClass('bg-components-button-ghost-bg-hover')
   })
 
   it('should render selected minute in the list', () => {
     const props = createOptionsProps({ selectedTime: dayjs('2024-01-01 05:30:00') })
     render(<Options {...props} />)
-    const selectedMinute = screen.getAllByRole('listitem').find(item => item.textContent === '30')
+    const selectedMinute = screen.getAllByRole('button').find(item => item.textContent === '30')
     expect(selectedMinute)!.toHaveClass('bg-components-button-ghost-bg-hover')
   })
diff --git a/web/app/components/base/date-and-time-picker/time-picker/options.tsx b/web/app/components/base/date-and-time-picker/time-picker/options.tsx
index 6f6e5e9c68..10e94f983d 100644
--- a/web/app/components/base/date-and-time-picker/time-picker/options.tsx
+++ b/web/app/components/base/date-and-time-picker/time-picker/options.tsx
@@ -1,6 +1,7 @@
 import type { FC } from 'react'
 import type { TimeOptionsProps } from '../types'
 import * as React from 'react'
+import OptionList from '../common/option-list'
 import OptionListItem from '../common/option-list-item'
 import { useTimeOptions } from '../hooks'
@@ -16,7 +17,7 @@ const Options: FC<TimeOptionsProps> = ({
   return (
     <div …>
       {/* Hour */}
-      <div …>
+      <OptionList>
         {
           hourOptions.map((hour) => {
             const isSelected = selectedTime?.format('hh') === hour
@@ -31,9 +32,9 @@
             )
           })
         }
-      </div>
+      </OptionList>
       {/* Minute */}
-      <div …>
+      <OptionList>
         {
           (minuteFilter ? minuteFilter(minuteOptions) : minuteOptions).map((minute) => {
             const isSelected = selectedTime?.format('mm') === minute
@@ -48,9 +49,9 @@
             )
           })
         }
-      </div>
+      </OptionList>
       {/* Period */}
-      <div …>
+      <OptionList>
         {
           periodOptions.map((period) => {
             const isSelected = selectedTime?.format('A') === period
@@ -66,7 +67,7 @@
             )
           })
         }
-      </div>
+      </OptionList>
     </div>
   )
 }
diff --git a/web/app/components/base/date-and-time-picker/types.ts b/web/app/components/base/date-and-time-picker/types.ts
index 2773fb7bc7..7dda1d013c 100644
--- a/web/app/components/base/date-and-time-picker/types.ts
+++ b/web/app/components/base/date-and-time-picker/types.ts
@@ -1,4 +1,4 @@
-import type { Placement } from '@floating-ui/react'
+import type { Placement } from '@langgenius/dify-ui/popover'
 import type { Dayjs } from 'dayjs'
 
 export enum ViewType {
diff --git a/web/app/components/base/date-and-time-picker/year-and-month-picker/options.tsx b/web/app/components/base/date-and-time-picker/year-and-month-picker/options.tsx
index 2288925579..2e162472f4 100644
--- a/web/app/components/base/date-and-time-picker/year-and-month-picker/options.tsx
+++ b/web/app/components/base/date-and-time-picker/year-and-month-picker/options.tsx
@@ -1,6 +1,7 @@
 import type { FC } from 'react'
 import type { YearAndMonthPickerOptionsProps } from '../types'
 import * as React from 'react'
+import OptionList from '../common/option-list'
 import OptionListItem from '../common/option-list-item'
 import { useMonths, useYearOptions } from '../hooks'
@@ -16,7 +17,7 @@ const Options: FC<YearAndMonthPickerOptionsProps> = ({
   return (
     <div …>
       {/* Month Picker */}
-      <div …>
+      <OptionList>
         {
           months.map((month, index) => {
             const isSelected = selectedMonth === index
@@ -31,9 +32,9 @@
             )
           })
         }
-      </div>
+      </OptionList>
       {/* Year Picker */}
-      <div …>
+      <OptionList>
         {
           yearOptions.map((year) => {
             const isSelected = selectedYear === year
@@ -48,7 +49,7 @@
             )
           })
         }
-      </div>
+      </OptionList>
     </div>
   )
 }
diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx
index fdef23ff27..bd1eee3df6 100644
--- a/web/app/components/datasets/create/step-two/language-select/index.tsx
+++ b/web/app/components/datasets/create/step-two/language-select/index.tsx
@@ -42,7 +42,6 @@ const LanguageSelect: FC = ({
       placement="bottom-start"
       sideOffset={4}
       popupClassName="w-max"
-      listClassName="no-scrollbar"
     >
       {supportedLanguages.map(({ prompt_name }) => (
diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/__tests__/popup.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/__tests__/popup.spec.tsx
index 318b5bcd73..42232a71c0 100644
--- a/web/app/components/header/account-setting/model-provider-page/model-selector/__tests__/popup.spec.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/model-selector/__tests__/popup.spec.tsx
@@ -55,7 +55,14 @@ vi.mock('../../hooks', async () => {
 })
 
 vi.mock('../popup-item', () => ({
-  default: ({ model }: { model: Model }) => <div …>{model.provider}</div>,
+  default: ({ model }: { model: Model }) => (
+    <div …>
+      {model.provider}
+      {model.models.map(modelItem => (
+        <div …>{modelItem.model}</div>
+      ))}
+    </div>
    + ), })) vi.mock('@/context/provider-context', () => ({ @@ -207,6 +214,156 @@ describe('Popup', () => { expect((input as HTMLInputElement).value).toBe('') }) + it('should show matching models when searching by model name', () => { + renderPopup( + , + ) + + fireEvent.change( + screen.getByPlaceholderText('datasetSettings.form.searchModel'), + { target: { value: 'claude' } }, + ) + + expect(screen.queryByText('openai')).not.toBeInTheDocument() + expect(screen.getByText('anthropic')).toBeInTheDocument() + expect(screen.getByText('claude-3')).toBeInTheDocument() + expect(screen.queryByText('gpt-4')).not.toBeInTheDocument() + expect(screen.queryByText('No model found for \u201Cclaude\u201D')).not.toBeInTheDocument() + }) + + it('should show empty search placeholder when no provider or model name matches', () => { + renderPopup( + , + ) + + fireEvent.change( + screen.getByPlaceholderText('datasetSettings.form.searchModel'), + { target: { value: 'mistral' } }, + ) + + expect(screen.getByText('No model found for \u201Cmistral\u201D'))!.toBeInTheDocument() + expect(screen.queryByText('openai')).not.toBeInTheDocument() + expect(screen.queryByText('gpt-4')).not.toBeInTheDocument() + }) + + it('should show all models of a provider when searching by provider label', () => { + renderPopup( + , + ) + + fireEvent.change( + screen.getByPlaceholderText('datasetSettings.form.searchModel'), + { target: { value: 'openai' } }, + ) + + expect(screen.getByText('openai'))!.toBeInTheDocument() + expect(screen.getByText('gpt-4'))!.toBeInTheDocument() + expect(screen.getByText('gpt-4o'))!.toBeInTheDocument() + expect(screen.queryByText('anthropic')).not.toBeInTheDocument() + expect(screen.queryByText('claude-3')).not.toBeInTheDocument() + }) + + it('should match by model provider key when model label does not contain the search text', () => { + renderPopup( + , + ) + + fireEvent.change( + screen.getByPlaceholderText('datasetSettings.form.searchModel'), + { target: { value: 'openai' } }, + ) + + expect(screen.getByText('azure_openai'))!.toBeInTheDocument() + expect(screen.getByText('gpt-4'))!.toBeInTheDocument() + }) + + it('should still apply scope features when matching by provider label', () => { + mockSupportFunctionCall.mockReturnValue(false) + + renderPopup( + , + ) + + fireEvent.change( + screen.getByPlaceholderText('datasetSettings.form.searchModel'), + { target: { value: 'openai' } }, + ) + + expect(screen.getByText('No model found for \u201Copenai\u201D'))!.toBeInTheDocument() + expect(screen.queryByText('gpt-4')).not.toBeInTheDocument() + expect(screen.queryByText('gpt-4-tool')).not.toBeInTheDocument() + }) + it('should not show compatible-only helper text when no scope features are applied', () => { renderPopup( { expect(screen.queryByText('common.modelProvider.selector.onlyCompatibleModelsShown')).not.toBeInTheDocument() }) - it('should show compatible-only helper banner when scope features are applied', () => { - const { container } = renderPopup( + it('should show compatible-only helper text when scope features are applied', () => { + renderPopup( { expect(screen.getByTestId('compatible-models-banner'))!.toBeInTheDocument() expect(screen.getByText('common.modelProvider.selector.onlyCompatibleModelsShown'))!.toBeInTheDocument() - expect(container.querySelector('.i-ri-information-2-fill'))!.toBeInTheDocument() + }) + + it('should keep search and footer outside the scrollable model list', () => { + renderPopup( + , + ) + + const scrollRegion = screen.getByRole('region', { name: 
'common.modelProvider.models' }) + const searchInput = screen.getByPlaceholderText('datasetSettings.form.searchModel') + const settingsButton = screen.getByRole('button', { name: /common\.modelProvider\.selector\.modelProviderSettings/ }) + + expect(scrollRegion)!.toBeInTheDocument() + expect(scrollRegion).not.toContainElement(searchInput) + expect(scrollRegion).not.toContainElement(settingsButton) + expect(scrollRegion).toContainElement(screen.getByTestId('compatible-models-banner')) }) it('should filter by scope features including toolCall and non-toolCall checks', () => { diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx index 9241c592f5..835821fd59 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/index.tsx @@ -88,7 +88,7 @@ const ModelSelector: FC = ({ placement="bottom-start" sideOffset={4} className={popupClassName} - popupClassName="overflow-hidden rounded-lg" + popupClassName="overflow-hidden rounded-xl" popupProps={{ style: { minWidth: '320px', width: 'var(--anchor-width, auto)' } }} > void + onInstallPlugin: (key: ModelProviderQuotaGetPaid) => void | Promise +} + +const MarketplaceSection: FC = ({ + marketplaceProviders, + marketplaceCollapsed, + installingProvider, + isMarketplacePluginsLoading, + theme, + onMarketplaceCollapsedChange, + onInstallPlugin, +}) => { + const { t } = useTranslation() + + if (marketplaceProviders.length === 0) + return null + + return ( + <> +
    +
    +
    +
    +
    +
    onMarketplaceCollapsedChange(!marketplaceCollapsed)} + > + {t('modelProvider.selector.fromMarketplace', { ns: 'common' })} + +
    +
    + {!marketplaceCollapsed && ( +
    + {marketplaceProviders.map((key) => { + const Icon = providerIconMap[key] + const isInstalling = installingProvider === key + return ( +
    +
    + + {modelNameMap[key]} +
    + +
    + ) + })} + + + {t('modelProvider.selector.discoverMoreInMarketplace', { ns: 'common' })} + + + +
    + )} +
    + + ) +} + +export default MarketplaceSection diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup-empty-state.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup-empty-state.tsx new file mode 100644 index 0000000000..dafd26387b --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup-empty-state.tsx @@ -0,0 +1,39 @@ +import type { FC } from 'react' +import { Button } from '@langgenius/dify-ui/button' +import { useTranslation } from 'react-i18next' + +type ModelSelectorEmptyStateProps = { + onConfigure: () => void +} + +const ModelSelectorEmptyState: FC = ({ + onConfigure, +}) => { + const { t } = useTranslation() + + return ( +
    +
    + +
    +
    +

    + {t('modelProvider.selector.noProviderConfigured', { ns: 'common' })} +

    +

    + {t('modelProvider.selector.noProviderConfiguredDesc', { ns: 'common' })} +

    +
    + +
    + ) +} + +export default ModelSelectorEmptyState diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup-item.tsx index 72c52a9429..ff9e6575bb 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/popup-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup-item.tsx @@ -107,7 +107,8 @@ const PopupItem: FC = ({ return (
    -
    + {/* Keep the sticky provider header above model rows while the list scrolls. */} +
    setCollapsed(prev => !prev)} diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup-layout.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup-layout.tsx new file mode 100644 index 0000000000..50bd098af1 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup-layout.tsx @@ -0,0 +1,130 @@ +import type { FC, ReactNode } from 'react' +import { + ScrollAreaContent, + ScrollAreaRoot, + ScrollAreaScrollbar, + ScrollAreaThumb, + ScrollAreaViewport, +} from '@langgenius/dify-ui/scroll-area' +import { useTranslation } from 'react-i18next' + +type ModelSelectorPopupFrameProps = { + children: ReactNode +} + +export const ModelSelectorPopupFrame: FC = ({ + children, +}) => { + return ( +
    + {children} +
    + ) +} + +type ModelSelectorSearchHeaderProps = { + searchText: string + onSearchTextChange: (value: string) => void +} + +export const ModelSelectorSearchHeader: FC = ({ + searchText, + onSearchTextChange, +}) => { + const { t } = useTranslation() + + return ( +
    +
    + + onSearchTextChange(e.target.value)} + /> + { + searchText && ( + onSearchTextChange('')} + /> + ) + } +
    +
    + ) +} + +type ModelSelectorScrollBodyProps = { + children: ReactNode + label: string +} + +export const ModelSelectorScrollBody: FC = ({ + children, + label, +}) => { + return ( + + + + {children} + + + {/* Keep the overlay scrollbar above sticky provider headers inside this scroll area. */} + + + + + ) +} + +export const CompatibleModelsNotice = () => { + const { t } = useTranslation() + + return ( +
    + {t('modelProvider.selector.onlyCompatibleModelsShown', { ns: 'common' })} +
    + ) +} + +type ModelProviderSettingsFooterProps = { + onOpenSettings: () => void +} + +export const ModelProviderSettingsFooter: FC = ({ + onOpenSettings, +}) => { + const { t } = useTranslation() + + return ( +
    + +
    + ) +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx index 47ddb55b6c..86bac84310 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx @@ -5,8 +5,6 @@ import type { ModelItem, } from '../declarations' import type { ModelProviderQuotaGetPaid } from '@/types/model-provider' -import { Button } from '@langgenius/dify-ui/button' -import { cn } from '@langgenius/dify-ui/cn' import { useSuspenseQuery } from '@tanstack/react-query' import { useTheme } from 'next-themes' import { useCallback, useMemo, useState } from 'react' @@ -19,7 +17,6 @@ import { useProviderContext } from '@/context/provider-context' import { systemFeaturesQueryOptions } from '@/service/system-features' import { useInstallPackageFromMarketPlace } from '@/service/use-plugins' import { supportFunctionCall } from '@/utils/tool-call' -import { getMarketplaceUrl } from '@/utils/var' import { CustomConfigurationStatusEnum, ModelFeatureEnum, @@ -29,8 +26,17 @@ import { useLanguage, useMarketplaceAllPlugins } from '../hooks' import CreditsExhaustedAlert from '../provider-added-card/model-auth-dropdown/credits-exhausted-alert' import { useTrialCredits } from '../provider-added-card/use-trial-credits' import { providerSupportsCredits } from '../supports-credits' -import { MODEL_PROVIDER_QUOTA_GET_PAID, modelNameMap, providerIconMap, providerKeyToPluginId } from '../utils' +import { MODEL_PROVIDER_QUOTA_GET_PAID, providerKeyToPluginId } from '../utils' +import MarketplaceSection from './marketplace-section' +import ModelSelectorEmptyState from './popup-empty-state' import PopupItem from './popup-item' +import { + CompatibleModelsNotice, + ModelProviderSettingsFooter, + ModelSelectorPopupFrame, + ModelSelectorScrollBody, + ModelSelectorSearchHeader, +} from './popup-layout' type PopupProps = { defaultModel?: DefaultModel @@ -137,18 +143,26 @@ const Popup: FC = ({ }, [aiCreditVisibleProviders, installedProviderMap, modelList]) const filteredModelList = useMemo(() => { + const normalizedSearch = searchText.toLowerCase() + const matchesLabel = (label: Record) => { + if (label[language] !== undefined) + return label[language].toLowerCase().includes(normalizedSearch) + return Object.values(label).some(value => + value.toLowerCase().includes(normalizedSearch), + ) + } + const filtered = installedModelList.map((model) => { - const matchesProviderSearch = !searchText - || model.provider.toLowerCase().includes(searchText.toLowerCase()) - || Object.values(model.label).some(label => label.toLowerCase().includes(searchText.toLowerCase())) + const providerMatched = !!searchText && ( + matchesLabel(model.label) + || model.provider.toLowerCase().includes(normalizedSearch) + ) const filteredModels = model.models .filter((modelItem) => { - if (modelItem.label[language] !== undefined) - return modelItem.label[language].toLowerCase().includes(searchText.toLowerCase()) - return Object.values(modelItem.label).some(label => - label.toLowerCase().includes(searchText.toLowerCase()), - ) + if (!searchText || providerMatched) + return true + return matchesLabel(modelItem.label) }) .filter((modelItem) => { if (scopeFeatures.length === 0) @@ -159,8 +173,12 @@ const Popup: FC = ({ return modelItem.features?.includes(feature) ?? 
false }) }) - if (!matchesProviderSearch || (filteredModels.length === 0 && !aiCreditVisibleProviders.has(model.provider))) + if ( + (searchText && filteredModels.length === 0) + || (!searchText && filteredModels.length === 0 && !aiCreditVisibleProviders.has(model.provider)) + ) { return null + } return { ...model, models: filteredModels } }).filter((model): model is Model => model !== null) @@ -181,166 +199,59 @@ const Popup: FC = ({ return MODEL_PROVIDER_QUOTA_GET_PAID.filter(key => !installedProviders.has(key)) }, [modelProviders]) + const handleOpenSettings = useCallback(() => { + onHide() + setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER }) + }, [onHide, setShowAccountSettingModal]) + return ( -
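The hunk above changes the search contract: a query that hits a provider's name or label now keeps every model under that provider, a model-level hit still filters model by model, and a provider with zero surviving models is dropped whenever a search is active. A standalone restatement of that rule (simplified types; the `aiCreditVisibleProviders` exception for the no-search branch is omitted):

```ts
type I18nLabel = Record<string, string>
type ModelEntry = { model: string, label: I18nLabel }
type ProviderEntry = { provider: string, label: I18nLabel, models: ModelEntry[] }

// Mirrors matchesLabel in the patch: prefer the active language's label,
// fall back to any language variant.
const makeLabelMatcher = (search: string, language: string) => {
  const normalized = search.toLowerCase()
  return (label: I18nLabel): boolean => {
    if (label[language] !== undefined)
      return label[language].toLowerCase().includes(normalized)
    return Object.values(label).some(value => value.toLowerCase().includes(normalized))
  }
}

export const filterProviders = (providers: ProviderEntry[], search: string, language: string): ProviderEntry[] => {
  const matchesLabel = makeLabelMatcher(search, language)
  return providers
    .map((entry) => {
      // A provider-level hit keeps every model under that provider.
      const providerMatched = !!search && (
        matchesLabel(entry.label)
        || entry.provider.toLowerCase().includes(search.toLowerCase())
      )
      const models = entry.models.filter(m => !search || providerMatched || matchesLabel(m.label))
      // With an active search, drop providers that matched nothing at all.
      return (search && models.length === 0) ? null : { ...entry, models }
    })
    .filter((entry): entry is ProviderEntry => entry !== null)
}
```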
    -
    -
    - - setSearchText(e.target.value)} - /> - { - searchText && ( - setSearchText('')} - /> - ) - } -
    - {scopeFeatures.length > 0 && ( -
    - -

    - {t('modelProvider.selector.onlyCompatibleModelsShown', { ns: 'common' })} -

    -
    - )} -
    + + {showCreditsExhaustedAlert && ( )} -
    - { - filteredModelList.map(model => ( - +
    + { + filteredModelList.map(model => ( + + )) + } + {!filteredModelList.length && !installedModelList.length && ( + - )) - } - {!filteredModelList.length && !installedModelList.length && ( -
    -
    - + )} + {!filteredModelList.length && installedModelList.length > 0 && ( +
    + {`No model found for \u201C${searchText}\u201D`}
    -
    -

    - {t('modelProvider.selector.noProviderConfigured', { ns: 'common' })} -

    -

    - {t('modelProvider.selector.noProviderConfiguredDesc', { ns: 'common' })} -

    -
    - -
    - )} - {!filteredModelList.length && installedModelList.length > 0 && ( -
    - {`No model found for \u201C${searchText}\u201D`} -
    - )} - {marketplaceProviders.length > 0 && ( - <> -
    -
    -
    -
    setMarketplaceCollapsed(prev => !prev)} - > - {t('modelProvider.selector.fromMarketplace', { ns: 'common' })} - -
    -
    - {!marketplaceCollapsed && ( - <> - {marketplaceProviders.map((key) => { - const Icon = providerIconMap[key] - const isInstalling = installingProvider === key - return ( -
    -
    - - {modelNameMap[key]} -
    - -
    - ) - })} - - - {t('modelProvider.selector.discoverMoreInMarketplace', { ns: 'common' })} - - - - - )} -
    - - )} -
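Everything removed in this block — the collapsible header, the per-provider install rows driven by `providerIconMap`/`modelNameMap`/`installingProvider`, and the `getMarketplaceUrl` link — now lives in the extracted `MarketplaceSection`. Its actual props are not visible in this diff; judging only from what the inline version consumed, a contract along these lines would cover it (all names here are assumptions):

```ts
// Hypothetical props contract inferred from the removed inline JSX; the real
// definition lives in ./marketplace-section and may differ.
type MarketplaceSectionProps = {
  providers: string[]                      // quota-get-paid keys not yet installed
  installingProvider?: string              // provider with an install in flight
  onInstall: (providerKey: string) => void // wraps useInstallPackageFromMarketPlace
  collapsed: boolean                       // replaces the local marketplaceCollapsed state
  onToggleCollapsed: () => void
}
```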
    -
    { - onHide() - setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER }) - }} - > - - {t('modelProvider.selector.modelProviderSettings', { ns: 'common' })} -
    -
    + )} + {scopeFeatures.length > 0 && ( + + )} + +
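With the layout pieces pulled into popup-layout.tsx, the render body reduces to a composition of named regions. The exact JSX is mangled in this diff, so the following is only an assumed skeleton: `onOpenSettings` is the one prop confirmed in popup-layout.tsx above, while the children-based nesting and the prop-less `CompatibleModelsNotice` are guesses:

```tsx
import type { FC } from 'react'
import {
  CompatibleModelsNotice,
  ModelProviderSettingsFooter,
  ModelSelectorPopupFrame,
  ModelSelectorScrollBody,
} from './popup-layout'

// Assumed skeleton of the recomposed popup; the search header and item list
// are elided because their props are not visible in this diff.
const PopupSkeleton: FC<{ onOpenSettings: () => void }> = ({ onOpenSettings }) => (
  <ModelSelectorPopupFrame>
    {/* ModelSelectorSearchHeader goes here */}
    <ModelSelectorScrollBody>
      {/* PopupItem list, empty states, MarketplaceSection */}
    </ModelSelectorScrollBody>
    <CompatibleModelsNotice />
    <ModelProviderSettingsFooter onOpenSettings={onOpenSettings} />
  </ModelSelectorPopupFrame>
)

export default PopupSkeleton
```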
    + + + ) } diff --git a/web/app/components/header/app-nav/__tests__/index.spec.tsx b/web/app/components/header/app-nav/__tests__/index.spec.tsx index 03f8edfacf..e7b546a589 100644 --- a/web/app/components/header/app-nav/__tests__/index.spec.tsx +++ b/web/app/components/header/app-nav/__tests__/index.spec.tsx @@ -2,13 +2,16 @@ import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' -import { useParams } from '@/next/navigation' +import { useParams, usePathname, useRouter } from '@/next/navigation' import { useInfiniteAppList } from '@/service/use-apps' +import { useCreateSnippetMutation, useInfiniteSnippetList, useSnippetApiDetail } from '@/service/use-snippets' import { AppModeEnum } from '@/types/app' import AppNav from '../index' vi.mock('@/next/navigation', () => ({ useParams: vi.fn(), + usePathname: vi.fn(), + useRouter: vi.fn(), })) vi.mock('react-i18next', () => ({ @@ -29,6 +32,19 @@ vi.mock('@/service/use-apps', () => ({ useInfiniteAppList: vi.fn(), })) +vi.mock('@/service/use-snippets', () => ({ + useCreateSnippetMutation: vi.fn(), + useInfiniteSnippetList: vi.fn(), + useSnippetApiDetail: vi.fn(), +})) + +vi.mock('@langgenius/dify-ui/toast', () => ({ + toast: { + error: vi.fn(), + success: vi.fn(), + }, +})) + vi.mock('@/app/components/app/create-app-dialog', () => ({ default: ({ show, onClose, onSuccess }: { show: boolean, onClose: () => void, onSuccess: () => void }) => show @@ -83,17 +99,67 @@ vi.mock('@/app/components/app/create-from-dsl-modal', () => ({ : null, })) +vi.mock('@/app/components/workflow/create-snippet-dialog', () => ({ + default: ({ + isOpen, + onClose, + onConfirm, + }: { + isOpen: boolean + onClose: () => void + onConfirm: (payload: { + name: string + description: string + icon: { type: 'emoji', icon: string, background: string } + }) => void + }) => + isOpen + ? ( + + ) + : null, +})) + vi.mock('../../nav', () => ({ default: ({ + createText, + curNav, + isApp, + link, onCreate, onLoadMore, navigationItems, }: { + createText: string + curNav?: { id: string, name: string } + isApp?: boolean + link: string onCreate: (state: string) => void onLoadMore?: () => void navigationItems?: Array<{ id: string, name: string, link: string }> }) => (
    +
    {link}
    +
    {String(isApp)}
    +
    {createText}
    +
    {curNav ? `${curNav.id}:${curNav.name}` : ''}
      {(navigationItems ?? []).map(item => (
    • {`${item.name} -> ${item.link}`}
@@ -127,10 +193,52 @@ const mockAppData = [ }, ] +const mockSnippetData = [ + { + id: 'snippet-1', + name: 'Snippet 1', + description: '', + icon_info: { + icon_type: 'emoji', + icon: '🧩', + icon_background: '#fff', + icon_url: null, + }, + is_published: true, + use_count: 0, + created_at: 1, + created_by: 'user-1', + updated_at: 1, + updated_by: 'user-1', + }, + { + id: 'snippet-2', + name: 'Snippet 2', + description: '', + icon_info: { + icon_type: 'emoji', + icon: '⚙️', + icon_background: '#000', + icon_url: null, + }, + is_published: true, + use_count: 1, + created_at: 1, + created_by: 'user-1', + updated_at: 1, + updated_by: 'user-1', + }, +] const mockUseParams = vi.mocked(useParams) +const mockUsePathname = vi.mocked(usePathname) +const mockUseRouter = vi.mocked(useRouter) const mockUseAppContext = vi.mocked(useAppContext) const mockUseAppStore = vi.mocked(useAppStore) const mockUseInfiniteAppList = vi.mocked(useInfiniteAppList) +const mockUseInfiniteSnippetList = vi.mocked(useInfiniteSnippetList) +const mockUseSnippetApiDetail = vi.mocked(useSnippetApiDetail) +const mockUseCreateSnippetMutation = vi.mocked(useCreateSnippetMutation) let mockAppDetail: { id: string, name: string } | null = null const setupDefaultMocks = (options?: { @@ -144,6 +252,8 @@ const fetchNextPage = options?.fetchNextPage ?? vi.fn() mockUseParams.mockReturnValue({ appId: 'app-1' } as ReturnType<typeof useParams>) + mockUsePathname.mockReturnValue('/app/app-1/workflow') + mockUseRouter.mockReturnValue({ push: vi.fn() } as unknown as ReturnType<typeof useRouter>) mockUseAppContext.mockReturnValue({ isCurrentWorkspaceEditor: options?.isEditor ?? false } as ReturnType<typeof useAppContext>) mockUseAppStore.mockImplementation((selector: unknown) => (selector as (state: { appDetail: { id: string, name: string } | null }) => unknown)({ appDetail: mockAppDetail })) mockUseInfiniteAppList.mockReturnValue({ @@ -153,10 +263,51 @@ isFetchingNextPage: false, refetch, } as ReturnType<typeof useInfiniteAppList>) + mockUseInfiniteSnippetList.mockReturnValue({ + data: undefined, + fetchNextPage: vi.fn(), + hasNextPage: false, + isFetchingNextPage: false, + } as unknown as ReturnType<typeof useInfiniteSnippetList>) + mockUseSnippetApiDetail.mockReturnValue({ + data: undefined, + } as ReturnType<typeof useSnippetApiDetail>) + mockUseCreateSnippetMutation.mockReturnValue({ + isPending: false, + mutate: vi.fn(), + } as unknown as ReturnType<typeof useCreateSnippetMutation>) return { refetch, fetchNextPage } } +const setupSnippetMocks = (options?: { + fetchNextPage?: () => void + hasNextPage?: boolean + mutate?: ReturnType<typeof vi.fn> +}) => { + const fetchNextPage = options?.fetchNextPage ?? vi.fn() + const mutate = options?.mutate ?? vi.fn() + + setupDefaultMocks({ isEditor: true }) + mockUseParams.mockReturnValue({ snippetId: 'snippet-1' } as ReturnType<typeof useParams>) + mockUsePathname.mockReturnValue('/snippets/snippet-1/orchestrate') + mockUseInfiniteSnippetList.mockReturnValue({ + data: { pages: [{ data: mockSnippetData }] }, + fetchNextPage, + hasNextPage: options?.hasNextPage ??
false, + isFetchingNextPage: false, + } as unknown as ReturnType<typeof useInfiniteSnippetList>) + mockUseSnippetApiDetail.mockReturnValue({ + data: mockSnippetData[0], + } as ReturnType<typeof useSnippetApiDetail>) + mockUseCreateSnippetMutation.mockReturnValue({ + isPending: false, + mutate, + } as unknown as ReturnType<typeof useCreateSnippetMutation>) + + return { fetchNextPage, mutate } +} + describe('AppNav', () => { beforeEach(() => { vi.clearAllMocks() @@ -338,4 +489,67 @@ expect(screen.getByText('App 1 -> /app/app-1/configuration')).toBeInTheDocument() }) }) + + it('should switch the main nav to snippet list and render snippet items on snippet detail routes', () => { + setupSnippetMocks() + + render(<AppNav />) + + expect(screen.getByTestId('nav-link')).toHaveTextContent('/snippets') + expect(screen.getByTestId('nav-is-app')).toHaveTextContent('false') + expect(screen.getByTestId('nav-current')).toHaveTextContent('snippet-1:Snippet 1') + expect(screen.getByTestId('nav-create-text')).toHaveTextContent('createFromBlank') + expect(screen.getByText('Snippet 1 -> /snippets/snippet-1/orchestrate')).toBeInTheDocument() + expect(screen.getByText('Snippet 2 -> /snippets/snippet-2/orchestrate')).toBeInTheDocument() + }) + + it('should not show stale snippet detail as the current nav while switching snippets', () => { + setupSnippetMocks() + mockUseParams.mockReturnValue({ snippetId: 'snippet-2' } as ReturnType<typeof useParams>) + mockUseSnippetApiDetail.mockReturnValue({ + data: mockSnippetData[0], + } as ReturnType<typeof useSnippetApiDetail>) + + render(<AppNav />) + + expect(screen.getByTestId('nav-current')).toBeEmptyDOMElement() + }) + + it('should load more snippets from the snippet selector when more data is available', async () => { + const user = userEvent.setup() + const { fetchNextPage } = setupSnippetMocks({ hasNextPage: true }) + + render(<AppNav />) + + await user.click(screen.getByTestId('load-more')) + expect(fetchNextPage).toHaveBeenCalledTimes(1) + }) + + it('should open the create snippet dialog from snippet nav create action', async () => { + const user = userEvent.setup() + const mutate = vi.fn() + setupSnippetMocks({ mutate }) + + render(<AppNav />) + + await user.click(screen.getByTestId('create-blank')) + expect(screen.getByTestId('create-snippet-dialog')).toBeInTheDocument() + + await user.click(screen.getByTestId('create-snippet-dialog')) + expect(mutate).toHaveBeenCalledWith({ + body: { + name: 'Created Snippet', + description: undefined, + icon_info: { + icon: '🤖', + icon_type: 'emoji', + icon_background: '#fff', + icon_url: undefined, + }, + }, + }, expect.objectContaining({ + onSuccess: expect.any(Function), + onError: expect.any(Function), + })) + }) }) diff --git a/web/app/components/header/app-nav/index.tsx b/web/app/components/header/app-nav/index.tsx index d98eaa1b3b..4f3a580ef8 100644 --- a/web/app/components/header/app-nav/index.tsx +++ b/web/app/components/header/app-nav/index.tsx @@ -1,19 +1,22 @@ 'use client' import type { NavItem } from '../nav/nav-selector' -import { - RiRobot2Fill, - RiRobot2Line, -} from '@remixicon/react' +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import { toast } from '@langgenius/dify-ui/toast' import { flatten } from 'es-toolkit/compat' -import { produce } from 'immer' -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' +import CreateSnippetDialog from '@/app/components/workflow/create-snippet-dialog' import { useAppContext } from '@/context/app-context' import
dynamic from '@/next/dynamic' -import { useParams } from '@/next/navigation' +import { useParams, usePathname, useRouter } from '@/next/navigation' import { useInfiniteAppList } from '@/service/use-apps' +import { + useCreateSnippetMutation, + useInfiniteSnippetList, + useSnippetApiDetail, +} from '@/service/use-snippets' import { AppModeEnum } from '@/types/app' import Nav from '../nav' @@ -23,13 +26,18 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro const AppNav = () => { const { t } = useTranslation() - const { appId } = useParams() + const { appId, snippetId } = useParams() + const { push } = useRouter() + const pathname = usePathname() + const isSnippetSegment = pathname === '/snippets' || pathname.startsWith('/snippets/') + const currentSnippetId = typeof snippetId === 'string' ? snippetId : '' const { isCurrentWorkspaceEditor } = useAppContext() const appDetail = useAppStore(state => state.appDetail) const [showNewAppDialog, setShowNewAppDialog] = useState(false) const [showNewAppTemplateDialog, setShowNewAppTemplateDialog] = useState(false) const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false) - const [navItems, setNavItems] = useState<NavItem[]>([]) + const [showCreateSnippetDialog, setShowCreateSnippetDialog] = useState(false) + const createSnippetMutation = useCreateSnippetMutation() const { data: appsData, @@ -41,14 +49,36 @@ page: 1, limit: 30, name: '', - }, { enabled: !!appId }) + }, { enabled: !!appId && !isSnippetSegment }) + + const { + data: snippetsData, + fetchNextPage: fetchNextSnippetPage, + hasNextPage: hasNextSnippetPage, + isFetchingNextPage: isFetchingNextSnippetPage, + } = useInfiniteSnippetList({ + page: 1, + limit: 30, + }, { enabled: !!currentSnippetId }) + + const { data: snippetDetail } = useSnippetApiDetail(currentSnippetId) const handleLoadMore = useCallback(() => { if (hasNextPage) fetchNextPage() }, [fetchNextPage, hasNextPage]) + const handleLoadMoreSnippet = useCallback(() => { + if (hasNextSnippetPage) + fetchNextSnippetPage() + }, [fetchNextSnippetPage, hasNextSnippetPage]) + const openModal = (state: string) => { + if (isSnippetSegment) { + setShowCreateSnippetDialog(true) + return + } + if (state === 'blank') setShowNewAppDialog(true) if (state === 'template') @@ -57,64 +87,125 @@ setShowCreateFromDSLModal(true) } - useEffect(() => { - if (appsData) { - const appItems = flatten((appsData.pages ??
[]).map(appData => appData.data)) - const navItems = appItems.map((app) => { - const link = ((isCurrentWorkspaceEditor, app) => { - if (!isCurrentWorkspaceEditor) { - return `/app/${app.id}/overview` - } - else { - if (app.mode === AppModeEnum.WORKFLOW || app.mode === AppModeEnum.ADVANCED_CHAT) - return `/app/${app.id}/workflow` - else - return `/app/${app.id}/configuration` - } - })(isCurrentWorkspaceEditor, app) - return { - id: app.id, - icon_type: app.icon_type, - icon: app.icon, - icon_background: app.icon_background, - icon_url: app.icon_url, - name: app.name, - mode: app.mode, - link, - } - }) - setNavItems(navItems as any) - } - }, [appsData, isCurrentWorkspaceEditor, setNavItems]) + const appNavItems = useMemo(() => { + if (!appsData) + return [] - // update current app name - useEffect(() => { - if (appDetail) { - const newNavItems = produce(navItems, (draft: NavItem[]) => { - navItems.forEach((app, index) => { - if (app.id === appDetail.id) - draft[index]!.name = appDetail.name - }) - }) - setNavItems(newNavItems) + const appItems = flatten((appsData.pages ?? []).map(appData => appData.data)) + + return appItems.map((app) => { + const link = (() => { + if (!isCurrentWorkspaceEditor) + return `/app/${app.id}/overview` + + if (app.mode === AppModeEnum.WORKFLOW || app.mode === AppModeEnum.ADVANCED_CHAT) + return `/app/${app.id}/workflow` + + return `/app/${app.id}/configuration` + })() + + return { + id: app.id, + icon_type: app.icon_type, + icon: app.icon, + icon_background: app.icon_background, + icon_url: app.icon_url, + name: appDetail?.id === app.id ? appDetail.name : app.name, + mode: app.mode, + link, + } + }) + }, [appDetail?.id, appDetail?.name, appsData, isCurrentWorkspaceEditor]) + + const snippetNavItems = useMemo(() => { + if (!snippetsData) + return [] + + const snippetItems = flatten((snippetsData.pages ?? []).map(snippetData => snippetData.data)) + + return snippetItems.map(snippet => ({ + id: snippet.id, + icon_type: snippet.icon_info.icon_type, + icon: snippet.icon_info.icon, + icon_background: snippet.icon_info.icon_background ?? null, + icon_url: snippet.icon_info.icon_url ?? null, + name: snippet.name, + link: `/snippets/${snippet.id}/orchestrate`, + })) + }, [snippetsData]) + + const currentSnippetNav = useMemo(() => { + if (!snippetDetail) + return + + if (snippetDetail.id !== currentSnippetId) + return + + return { + id: snippetDetail.id, + icon_type: snippetDetail.icon_info.icon_type, + icon: snippetDetail.icon_info.icon, + icon_background: snippetDetail.icon_info.icon_background ?? null, + icon_url: snippetDetail.icon_info.icon_url ?? null, + name: snippetDetail.name, } - }, [appDetail, navItems]) + }, [currentSnippetId, snippetDetail]) + + const handleCreateSnippet = useCallback(({ + name, + description, + icon, + }: { + name: string + description: string + icon: AppIconSelection + }) => { + createSnippetMutation.mutate({ + body: { + name, + description: description || undefined, + icon_info: { + icon: icon.type === 'emoji' ? icon.icon : icon.fileId, + icon_type: icon.type, + icon_background: icon.type === 'emoji' ? icon.background : undefined, + icon_url: icon.type === 'image' ? icon.url : undefined, + }, + }, + }, { + onSuccess: (snippet) => { + toast.success(t('snippet.createSuccess', { ns: 'workflow' })) + setShowCreateSnippetDialog(false) + push(`/snippets/${snippet.id}/orchestrate`) + }, + onError: (error) => { + toast.error(error instanceof Error ?
error.message : t('createFailed', { ns: 'snippet' })) + }, + }) + }, [createSnippetMutation, push, t]) + + const currentNav = isSnippetSegment ? currentSnippetNav : appDetail + const currentNavigationItems = isSnippetSegment ? snippetNavItems : appNavItems + const currentCreateText = isSnippetSegment + ? t('createFromBlank', { ns: 'snippet' }) + : t('menus.newApp', { ns: 'common' }) + const currentLoadMore = isSnippetSegment ? handleLoadMoreSnippet : handleLoadMore + const currentIsLoadingMore = isSnippetSegment ? isFetchingNextSnippetPage : isFetchingNextPage return ( <>
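The branching that decides which dataset feeds the shared `Nav` is easiest to see as pure helpers. This is only a restatement of the logic above, not new behaviour; the string literals stand in for the real `AppModeEnum` values, whose exact spellings are not visible in this diff:

```ts
// '/snippets' and anything under it switches the header nav into snippet mode.
export const isSnippetRoute = (pathname: string): boolean =>
  pathname === '/snippets' || pathname.startsWith('/snippets/')

// Assumed stand-ins for AppModeEnum.WORKFLOW / AppModeEnum.ADVANCED_CHAT from '@/types/app'.
const WORKFLOW = 'workflow'
const ADVANCED_CHAT = 'advanced-chat'

// Editors land on the editing surface for workflow-style apps;
// non-editors always land on the overview page.
export const appNavLink = (appId: string, mode: string, isCurrentWorkspaceEditor: boolean): string => {
  if (!isCurrentWorkspaceEditor)
    return `/app/${appId}/overview`
  return (mode === WORKFLOW || mode === ADVANCED_CHAT)
    ? `/app/${appId}/workflow`
    : `/app/${appId}/configuration`
}

// Snippet rows always deep-link to the orchestrate view.
export const snippetNavLink = (snippetId: string): string =>
  `/snippets/${snippetId}/orchestrate`
```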
    -
    + {errorPlugins.map(plugin => ( = ({ onClear={() => onClearSingle(plugin.taskId, plugin.plugin_unique_identifier)} /> ))} -
    + )}
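Only the call signature `onClearSingle(taskId, plugin_unique_identifier)` is visible in the hunk above, so the following is purely illustrative of what clearing one errored plugin from a task list might look like; the real handler lives with the plugin task state, and every name and shape here is an assumption:

```ts
// Illustrative helper only — not the actual implementation.
type PluginError = { plugin_unique_identifier: string }
type PluginTask = { id: string, errors: PluginError[] }

export const clearSingleError = (tasks: PluginTask[], taskId: string, pluginId: string): PluginTask[] =>
  tasks
    .map(task => task.id === taskId
      ? { ...task, errors: task.errors.filter(e => e.plugin_unique_identifier !== pluginId) }
      : task)
    .filter(task => task.errors.length > 0) // drop tasks with nothing left to show
```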
diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx index f3102c4909..00fcb7e072 100644 --- a/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx +++ b/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx @@ -117,7 +117,7 @@ const PluginTasks = () => { { }) describe('StrategyPicker (strategy-picker.tsx)', () => { - const defaultProps = { - value: AUTO_UPDATE_STRATEGY.disabled, - onChange: vi.fn(), + const i18nKeyByStrategy: Record<AUTO_UPDATE_STRATEGY, string> = { + [AUTO_UPDATE_STRATEGY.disabled]: 'disabled', + [AUTO_UPDATE_STRATEGY.fixOnly]: 'fixOnly', + [AUTO_UPDATE_STRATEGY.latest]: 'latest', + } + + const triggerName = (strategy: AUTO_UPDATE_STRATEGY) => + new RegExp(`plugin\\.autoUpdate\\.strategy\\.${i18nKeyByStrategy[strategy]}\\.name`, 'i') + + const findOption = async (key: 'disabled' | 'fixOnly' | 'latest') => { + const options = await screen.findAllByRole('menuitemradio') + const option = options.find(item => + item.textContent?.includes(`plugin.autoUpdate.strategy.${key}.name`), + ) + if (!option) + throw new Error(`Strategy option "${key}" not found`) + return option } describe('Rendering', () => { it('should render trigger button with current strategy label', () => { - // Act - render() + render() - // Assert - expect(screen.getByRole('button', { name: /plugin\.autoUpdate\.strategy\.disabled\.name/i })).toBeInTheDocument() + expect(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.disabled) })).toBeInTheDocument() }) it('should not render dropdown content when closed', () => { - // Act - render() + render() - // Assert - expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument() + expect(screen.queryByRole('menu')).not.toBeInTheDocument() }) - it('should render all strategy options when open', () => { - // Arrange - mockPortalOpen = true + it('should render all strategy options when open', async () => { + const user = userEvent.setup() + render() - // Act - render() - fireEvent.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.disabled) })) - // Wait for portal to open - if (mockPortalOpen) { - // Assert all options visible (use getAllByText for strategy name as it appears in both trigger and dropdown) - expect(screen.getAllByText('plugin.autoUpdate.strategy.disabled.name').length).toBeGreaterThanOrEqual(1) - expect(screen.getByText('plugin.autoUpdate.strategy.fixOnly.name')).toBeInTheDocument() - expect(screen.getByText('plugin.autoUpdate.strategy.latest.name')).toBeInTheDocument() - } + const options = await screen.findAllByRole('menuitemradio') + expect(options).toHaveLength(3) + expect(options.some(o => o.textContent?.includes('plugin.autoUpdate.strategy.disabled.name'))).toBe(true) + expect(options.some(o => o.textContent?.includes('plugin.autoUpdate.strategy.fixOnly.name'))).toBe(true) + expect(options.some(o => o.textContent?.includes('plugin.autoUpdate.strategy.latest.name'))).toBe(true) }) }) describe('User Interactions', () => { - it('should toggle dropdown when trigger is clicked', () => { - // Act - render() - - // Assert - initially closed - expect(mockPortalOpen).toBe(false) - - // Act - click trigger - fireEvent.click(screen.getByTestId('portal-trigger')) - - // Assert - portal trigger element should still be in document - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() - }) - - it('should call onChange with fixOnly when Bug Fixes Only option is clicked',
() => { - // Arrange - force portal content to be visible for testing option selection - forcePortalContentVisible = true - const onChange = vi.fn() - - // Act - render() - - // Find and click the "Bug Fixes Only" option - const fixOnlyOption = screen.getByText('plugin.autoUpdate.strategy.fixOnly.name').closest('div[class*="cursor-pointer"]') - expect(fixOnlyOption).toBeInTheDocument() - fireEvent.click(fixOnlyOption!) - - // Assert - expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.fixOnly) - }) - - it('should call onChange with latest when Latest Version option is clicked', () => { - // Arrange - force portal content to be visible for testing option selection - forcePortalContentVisible = true - const onChange = vi.fn() - - // Act - render() - - // Find and click the "Latest Version" option - const latestOption = screen.getByText('plugin.autoUpdate.strategy.latest.name').closest('div[class*="cursor-pointer"]') - expect(latestOption).toBeInTheDocument() - fireEvent.click(latestOption!) - - // Assert - expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.latest) - }) - - it('should call onChange with disabled when Disabled option is clicked', () => { - // Arrange - force portal content to be visible for testing option selection - forcePortalContentVisible = true - const onChange = vi.fn() - - // Act - render() - - // Find and click the "Disabled" option - need to find the one in the dropdown, not the button - const disabledOptions = screen.getAllByText('plugin.autoUpdate.strategy.disabled.name') - // The second one should be in the dropdown - const dropdownOption = disabledOptions.find(el => el.closest('div[class*="cursor-pointer"]')) - expect(dropdownOption).toBeInTheDocument() - fireEvent.click(dropdownOption!.closest('div[class*="cursor-pointer"]')!) - - // Assert - expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.disabled) - }) - - it('should stop event propagation when option is clicked', () => { - // Arrange - force portal content to be visible - forcePortalContentVisible = true - const onChange = vi.fn() - const parentClickHandler = vi.fn() - - // Act - render( -
    - -
, ) - - // Click an option - const fixOnlyOption = screen.getByText('plugin.autoUpdate.strategy.fixOnly.name').closest('div[class*="cursor-pointer"]') - fireEvent.click(fixOnlyOption!) - - // Assert - onChange is called but parent click handler should not propagate - expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.fixOnly) - }) - - it('should render check icon for currently selected option', () => { - // Arrange - force portal content to be visible - forcePortalContentVisible = true - - // Act - render with fixOnly selected - render() - - // Assert - RiCheckLine should be rendered (check icon) - // Find all "Bug Fixes Only" texts and get the one in the dropdown (has cursor-pointer parent) - const allFixOnlyTexts = screen.getAllByText('plugin.autoUpdate.strategy.fixOnly.name') - const dropdownOption = allFixOnlyTexts.find(el => el.closest('div[class*="cursor-pointer"]')) - const optionContainer = dropdownOption?.closest('div[class*="cursor-pointer"]') - expect(optionContainer).toBeInTheDocument() - // The check icon SVG should exist within the option - expect(optionContainer?.querySelector('svg')).toBeInTheDocument() - }) - - it('should not render check icon for non-selected options', () => { - // Arrange - force portal content to be visible - forcePortalContentVisible = true - - // Act - render with disabled selected + it('should open and close the menu when the trigger is clicked', async () => { + const user = userEvent.setup() render() - - // Assert - check the Latest Version option should not have check icon - const latestOption = screen.getByText('plugin.autoUpdate.strategy.latest.name').closest('div[class*="cursor-pointer"]') - // The svg should only be in selected option, not in non-selected - const checkIconContainer = latestOption?.querySelector('div.mr-1') - // Non-selected option should have empty check icon container - expect(checkIconContainer?.querySelector('svg')).toBeNull() + const trigger = screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.disabled) }) + expect(screen.queryByRole('menu')).not.toBeInTheDocument() + + await user.click(trigger) + expect(await screen.findByRole('menu')).toBeInTheDocument() + }) + + it.each<[AUTO_UPDATE_STRATEGY, 'disabled' | 'fixOnly' | 'latest', AUTO_UPDATE_STRATEGY]>([ + [AUTO_UPDATE_STRATEGY.disabled, 'fixOnly', AUTO_UPDATE_STRATEGY.fixOnly], + [AUTO_UPDATE_STRATEGY.disabled, 'latest', AUTO_UPDATE_STRATEGY.latest], + [AUTO_UPDATE_STRATEGY.fixOnly, 'disabled', AUTO_UPDATE_STRATEGY.disabled], + ])('should call onChange with %s -> %s when option is selected', async (initial, optionKey, expected) => { + const user = userEvent.setup() + const onChange = vi.fn() + render() + + await user.click(screen.getByRole('button', { name: triggerName(initial) })) + await user.click(await findOption(optionKey)) + + expect(onChange).toHaveBeenCalledWith(expected) + }) + + it('should mark only the currently selected option with aria-checked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.fixOnly) })) + + const options = await screen.findAllByRole('menuitemradio') + const checked = options.filter(o => o.getAttribute('aria-checked') === 'true') + + expect(checked).toHaveLength(1) + expect(checked[0]).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name') + }) + + it('should render the check indicator inside the selected option only', async () => { + const user = userEvent.setup() + render() + + await
user.click(screen.getByRole('button', { name: triggerName(AUTO_UPDATE_STRATEGY.fixOnly) })) + + const fixOnlyOption = await findOption('fixOnly') + const latestOption = await findOption('latest') + + expect(fixOnlyOption.querySelector('.i-ri-check-line')).toBeInTheDocument() + expect(latestOption.querySelector('.i-ri-check-line')).toBeNull() }) }) }) @@ -1280,7 +1219,9 @@ describe('auto-update-setting', () => { render() // Assert - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect( + screen.getByRole('button', { name: /plugin\.autoUpdate\.strategy\.fixOnly\.name/i }), + ).toBeInTheDocument() }) it('should show time picker when strategy is not disabled', () => { @@ -1407,16 +1348,27 @@ describe('auto-update-setting', () => { }) describe('User Interactions', () => { - it('should call onChange with updated strategy when strategy changes', () => { + it('should call onChange with updated strategy when strategy changes', async () => { // Arrange + const user = userEvent.setup() const onChange = vi.fn() - const payload = createMockAutoUpdateConfig() + const payload = createMockAutoUpdateConfig({ strategy_setting: AUTO_UPDATE_STRATEGY.fixOnly }) // Act render() - // Assert - component renders with strategy picker - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + await user.click( + screen.getByRole('button', { name: /plugin\.autoUpdate\.strategy\.fixOnly\.name/i }), + ) + const latestOption = (await screen.findAllByRole('menuitemradio')).find(item => + item.textContent?.includes('plugin.autoUpdate.strategy.latest.name'), + )! + await user.click(latestOption) + + // Assert + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ strategy_setting: AUTO_UPDATE_STRATEGY.latest }), + ) }) it('should call onChange with updated time when time changes', () => { diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx index e287e48985..9fe089c34e 100644 --- a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx @@ -1,62 +1,12 @@ -import { fireEvent, render, screen } from '@testing-library/react' -import { beforeEach, describe, expect, it, vi } from 'vitest' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { describe, expect, it, vi } from 'vitest' import StrategyPicker from '../strategy-picker' import { AUTO_UPDATE_STRATEGY } from '../types' -let portalOpen = false - -vi.mock('@langgenius/dify-ui/button', () => ({ - Button: ({ - children, - }: { - children: React.ReactNode - }) => {children}, -})) - -vi.mock('@/app/components/base/portal-to-follow-elem', async () => { - const _React = await import('react') - return { - PortalToFollowElem: ({ - open, - children, - }: { - open: boolean - children: React.ReactNode - }) => { - portalOpen = open - return
    {children}
    - }, - PortalToFollowElemTrigger: ({ - children, - onClick, - }: { - children: React.ReactNode - onClick: (event: { stopPropagation: () => void, nativeEvent: { stopImmediatePropagation: () => void } }) => void - }) => ( - - ), - PortalToFollowElemContent: ({ - children, - }: { - children: React.ReactNode - }) => portalOpen ?
    {children}
    : null, - } -}) +const triggerName = (key: string) => new RegExp(`plugin\\.autoUpdate\\.strategy\\.${key}\\.name`, 'i') describe('StrategyPicker', () => { - beforeEach(() => { - vi.clearAllMocks() - portalOpen = false - }) - it('renders the selected strategy label in the trigger', () => { render( { />, ) - expect(screen.getByTestId('trigger')).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name') + expect(screen.getByRole('button', { name: triggerName('fixOnly') })).toBeInTheDocument() + expect(screen.queryByRole('menu')).not.toBeInTheDocument() }) - it('opens the option list when the trigger is clicked', () => { + it('opens the option list when the trigger is clicked', async () => { + const user = userEvent.setup() render( { />, ) - fireEvent.click(screen.getByTestId('trigger')) + await user.click(screen.getByRole('button', { name: triggerName('disabled') })) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() - expect(screen.getByTestId('portal-content').querySelectorAll('svg')).toHaveLength(1) + const options = await screen.findAllByRole('menuitemradio') + expect(options).toHaveLength(3) expect(screen.getByText('plugin.autoUpdate.strategy.latest.description')).toBeInTheDocument() }) - it('calls onChange when a new strategy is selected', () => { + it('marks only the currently selected strategy as checked', async () => { + const user = userEvent.setup() + render( + , + ) + + await user.click(screen.getByRole('button', { name: triggerName('fixOnly') })) + + const checkedOptions = (await screen.findAllByRole('menuitemradio')) + .filter(item => item.getAttribute('aria-checked') === 'true') + + expect(checkedOptions).toHaveLength(1) + expect(checkedOptions[0]).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name') + }) + + it('calls onChange and closes the menu when a new strategy is selected', async () => { + const user = userEvent.setup() const onChange = vi.fn() render( { />, ) - fireEvent.click(screen.getByTestId('trigger')) - fireEvent.click(screen.getByText('plugin.autoUpdate.strategy.latest.name')) + await user.click(screen.getByRole('button', { name: triggerName('disabled') })) + const latestOption = (await screen.findAllByRole('menuitemradio')) + .find(item => item.textContent?.includes('plugin.autoUpdate.strategy.latest.name'))! + await user.click(latestOption) expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.latest) + expect(await screen.findByRole('button', { name: triggerName('disabled') })).toBeInTheDocument() }) }) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/index.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/index.tsx index dc5d376eb3..d7d6fcd35f 100644 --- a/web/app/components/plugins/reference-setting-modal/auto-update-setting/index.tsx +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/index.tsx @@ -105,7 +105,7 @@ const AutoUpdateSetting: FC = ({ const renderTimePickerTrigger = useCallback(({ inputElem, onClick, isOpen }: TriggerParams) => { return (
    @@ -137,7 +137,7 @@ const AutoUpdateSetting: FC = ({ <>
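Both spec rewrites in this patch follow the same pattern: drop the hand-rolled portal mocks and test-id hooks, and drive the real component through its accessibility tree (`button`, `menu`, `menuitemradio`, `aria-checked`) with `userEvent`. A minimal self-contained example of the pattern, using a toy component rather than the real `StrategyPicker`:

```tsx
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { describe, expect, it, vi } from 'vitest'

// Toy radio menu exposing the same ARIA roles the real component relies on.
const ToyPicker = ({ value, onChange }: { value: string, onChange: (v: string) => void }) => (
  <div role="menu">
    {['disabled', 'fixOnly', 'latest'].map(key => (
      <div
        key={key}
        role="menuitemradio"
        aria-checked={key === value}
        onClick={() => onChange(key)}
      >
        {key}
      </div>
    ))}
  </div>
)

describe('role-based queries', () => {
  it('selects options by role instead of test ids', async () => {
    const user = userEvent.setup()
    const onChange = vi.fn()
    render(<ToyPicker value="disabled" onChange={onChange} />)

    // Query the option list through roles, exactly as a user agent sees it.
    const options = screen.getAllByRole('menuitemradio')
    expect(options).toHaveLength(3)
    expect(options.filter(o => o.getAttribute('aria-checked') === 'true')).toHaveLength(1)

    // The accessible name comes from the option's text content.
    await user.click(screen.getByRole('menuitemradio', { name: 'latest' }))
    expect(onChange).toHaveBeenCalledWith('latest')
  })
})
```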