From 3cb944f31859b890ce1f78d5926a4dfeea9f2325 Mon Sep 17 00:00:00 2001 From: hj24 Date: Mon, 8 Dec 2025 17:54:57 +0800 Subject: [PATCH] feat: enable tenant isolation on duplicate document indexing tasks (#29080) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/dataset_service.py | 8 +- .../document_indexing_proxy/__init__.py | 11 + api/services/document_indexing_proxy/base.py | 111 +++ .../batch_indexing_base.py | 76 ++ .../document_indexing_task_proxy.py | 12 + .../duplicate_document_indexing_task_proxy.py | 15 + api/services/document_indexing_task_proxy.py | 83 -- .../rag_pipeline/rag_pipeline_task_proxy.py | 11 +- api/tasks/document_indexing_task.py | 10 +- api/tasks/duplicate_document_indexing_task.py | 82 ++ .../priority_rag_pipeline_run_task.py | 6 +- .../rag_pipeline/rag_pipeline_run_task.py | 6 +- .../test_duplicate_document_indexing_task.py | 763 ++++++++++++++++++ .../services/document_indexing_task_proxy.py | 70 +- .../test_document_indexing_task_proxy.py | 37 +- ..._duplicate_document_indexing_task_proxy.py | 363 +++++++++ .../tasks/test_dataset_indexing_task.py | 24 +- .../test_duplicate_document_indexing_task.py | 567 +++++++++++++ 18 files changed, 2097 insertions(+), 158 deletions(-) create mode 100644 api/services/document_indexing_proxy/__init__.py create mode 100644 api/services/document_indexing_proxy/base.py create mode 100644 api/services/document_indexing_proxy/batch_indexing_base.py create mode 100644 api/services/document_indexing_proxy/document_indexing_task_proxy.py create mode 100644 api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py delete mode 100644 api/services/document_indexing_task_proxy.py create mode 100644 api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py create mode 100644 api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py create mode 100644 api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 208ebcb018..bb09311349 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -51,7 +51,8 @@ from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding from models.workflow import Workflow -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -82,7 +83,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_update_task import document_indexing_update_task -from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -1761,7 +1761,9 @@ class DocumentService: if document_ids: DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + DuplicateDocumentIndexingTaskProxy( + dataset.tenant_id, dataset.id, duplicate_document_ids + ).delay() except LockNotOwnedError: pass diff --git a/api/services/document_indexing_proxy/__init__.py b/api/services/document_indexing_proxy/__init__.py new file mode 100644 index 0000000000..74195adbe1 --- /dev/null +++ b/api/services/document_indexing_proxy/__init__.py @@ -0,0 +1,11 @@ +from .base import DocumentTaskProxyBase +from .batch_indexing_base import BatchDocumentIndexingProxy +from .document_indexing_task_proxy import DocumentIndexingTaskProxy +from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy + +__all__ = [ + "BatchDocumentIndexingProxy", + "DocumentIndexingTaskProxy", + "DocumentTaskProxyBase", + "DuplicateDocumentIndexingTaskProxy", +] diff --git a/api/services/document_indexing_proxy/base.py b/api/services/document_indexing_proxy/base.py new file mode 100644 index 0000000000..56e47857c9 --- /dev/null +++ b/api/services/document_indexing_proxy/base.py @@ -0,0 +1,111 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import cached_property +from typing import Any, ClassVar + +from enums.cloud_plan import CloudPlan +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) + + +class DocumentTaskProxyBase(ABC): + """ + Base proxy for all document processing tasks. + + Handles common logic: + - Feature/billing checks + - Dispatch routing based on plan + + Subclasses must define: + - QUEUE_NAME: Redis queue identifier + - NORMAL_TASK_FUNC: Task function for normal priority + - PRIORITY_TASK_FUNC: Task function for high priority + """ + + QUEUE_NAME: ClassVar[str] + NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] + PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] + + def __init__(self, tenant_id: str, dataset_id: str): + """ + Initialize with minimal required parameters. + + Args: + tenant_id: Tenant identifier for billing/features + dataset_id: Dataset identifier for logging + """ + self._tenant_id = tenant_id + self._dataset_id = dataset_id + + @cached_property + def features(self): + return FeatureService.get_features(self._tenant_id) + + @abstractmethod + def _send_to_direct_queue(self, task_func: Callable[..., Any]): + """ + Send task directly to Celery queue without tenant isolation. + + Subclasses implement this to pass task-specific parameters. + + Args: + task_func: The Celery task function to call + """ + pass + + @abstractmethod + def _send_to_tenant_queue(self, task_func: Callable[..., Any]): + """ + Send task to tenant-isolated queue. + + Subclasses implement this to handle queue management. + + Args: + task_func: The Celery task function to call + """ + pass + + def _send_to_default_tenant_queue(self): + """Route to normal priority with tenant isolation.""" + self._send_to_tenant_queue(self.NORMAL_TASK_FUNC) + + def _send_to_priority_tenant_queue(self): + """Route to priority queue with tenant isolation.""" + self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC) + + def _send_to_priority_direct_queue(self): + """Route to priority queue without tenant isolation.""" + self._send_to_direct_queue(self.PRIORITY_TASK_FUNC) + + def _dispatch(self): + """ + Dispatch task based on billing plan. + + Routing logic: + - Sandbox plan → normal queue + tenant isolation + - Paid plans → priority queue + tenant isolation + - Self-hosted → priority queue, no isolation + """ + logger.info( + "dispatch args: %s - %s - %s", + self._tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + # dispatch to different indexing queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan + self._send_to_default_tenant_queue() + else: + # dispatch to priority pipeline queue with tenant self sub queue for other plans + self._send_to_priority_tenant_queue() + else: + # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue() + + def delay(self): + """Public API: Queue the task asynchronously.""" + self._dispatch() diff --git a/api/services/document_indexing_proxy/batch_indexing_base.py b/api/services/document_indexing_proxy/batch_indexing_base.py new file mode 100644 index 0000000000..dd122f34a8 --- /dev/null +++ b/api/services/document_indexing_proxy/batch_indexing_base.py @@ -0,0 +1,76 @@ +import logging +from collections.abc import Callable, Sequence +from dataclasses import asdict +from typing import Any + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue + +from .base import DocumentTaskProxyBase + +logger = logging.getLogger(__name__) + + +class BatchDocumentIndexingProxy(DocumentTaskProxyBase): + """ + Base proxy for batch document indexing tasks (document_ids in plural). + + Adds: + - Tenant isolated queue management + - Batch document handling + """ + + def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Initialize with batch documents. + + Args: + tenant_id: Tenant identifier + dataset_id: Dataset identifier + document_ids: List of document IDs to process + """ + super().__init__(tenant_id, dataset_id) + self._document_ids = document_ids + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME) + + def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to direct queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids) + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + + def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to tenant-isolated queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info( + "tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME + ) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks( + [ + asdict( + DocumentTask( + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + ) + ] + ) + logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) diff --git a/api/services/document_indexing_proxy/document_indexing_task_proxy.py b/api/services/document_indexing_proxy/document_indexing_task_proxy.py new file mode 100644 index 0000000000..fce79a8387 --- /dev/null +++ b/api/services/document_indexing_proxy/document_indexing_task_proxy.py @@ -0,0 +1,12 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task + + +class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "document_indexing" + NORMAL_TASK_FUNC = normal_document_indexing_task + PRIORITY_TASK_FUNC = priority_document_indexing_task diff --git a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..277cfbdcf1 --- /dev/null +++ b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py @@ -0,0 +1,15 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.duplicate_document_indexing_task import ( + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + + +class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for duplicate document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing" + NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task + PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task diff --git a/api/services/document_indexing_task_proxy.py b/api/services/document_indexing_task_proxy.py deleted file mode 100644 index 861c84b586..0000000000 --- a/api/services/document_indexing_task_proxy.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from collections.abc import Callable, Sequence -from dataclasses import asdict -from functools import cached_property - -from core.entities.document_task import DocumentTask -from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from services.feature_service import FeatureService -from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task - -logger = logging.getLogger(__name__) - - -class DocumentIndexingTaskProxy: - def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): - self._tenant_id = tenant_id - self._dataset_id = dataset_id - self._document_ids = document_ids - self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") - - @cached_property - def features(self): - return FeatureService.get_features(self._tenant_id) - - def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to direct queue", self._dataset_id) - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - - def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to tenant queue", self._dataset_id) - if self._tenant_isolated_task_queue.get_task_key(): - # Add to waiting queue using List operations (lpush) - self._tenant_isolated_task_queue.push_tasks( - [ - asdict( - DocumentTask( - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - ) - ] - ) - logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids) - else: - # Set flag and execute task - self._tenant_isolated_task_queue.set_task_waiting_time() - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids) - - def _send_to_default_tenant_queue(self): - self._send_to_tenant_queue(normal_document_indexing_task) - - def _send_to_priority_tenant_queue(self): - self._send_to_tenant_queue(priority_document_indexing_task) - - def _send_to_priority_direct_queue(self): - self._send_to_direct_queue(priority_document_indexing_task) - - def _dispatch(self): - logger.info( - "dispatch args: %s - %s - %s", - self._tenant_id, - self.features.billing.enabled, - self.features.billing.subscription.plan, - ) - # dispatch to different indexing queue with tenant isolation when billing enabled - if self.features.billing.enabled: - if self.features.billing.subscription.plan == CloudPlan.SANDBOX: - # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan - self._send_to_default_tenant_queue() - else: - # dispatch to priority pipeline queue with tenant self sub queue for other plans - self._send_to_priority_tenant_queue() - else: - # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise - self._send_to_priority_direct_queue() - - def delay(self): - self._dispatch() diff --git a/api/services/rag_pipeline/rag_pipeline_task_proxy.py b/api/services/rag_pipeline/rag_pipeline_task_proxy.py index 94dd7941da..1a7b104a70 100644 --- a/api/services/rag_pipeline/rag_pipeline_task_proxy.py +++ b/api/services/rag_pipeline/rag_pipeline_task_proxy.py @@ -38,21 +38,24 @@ class RagPipelineTaskProxy: upload_file = FileService(db.engine).upload_text( json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id ) + logger.info( + "tenant %s upload %d invoke entities", self._dataset_tenant_id, len(self._rag_pipeline_invoke_entities) + ) return upload_file.id def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to direct queue", upload_file_id) + logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id) task_func.delay( # type: ignore rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to tenant queue", upload_file_id) + logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id) if self._tenant_isolated_task_queue.get_task_key(): # Add to waiting queue using List operations (lpush) self._tenant_isolated_task_queue.push_tasks([upload_file_id]) - logger.info("push tasks: %s", upload_file_id) + logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id) else: # Set flag and execute task self._tenant_isolated_task_queue.set_task_waiting_time() @@ -60,7 +63,7 @@ class RagPipelineTaskProxy: rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) - logger.info("init tasks: %s", upload_file_id) + logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id) def _send_to_default_tenant_queue(self, upload_file_id: str): self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index fee4430612..acbdab631b 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue( try: _document_indexing(dataset_id, document_ids) except Exception: - logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id) + logger.exception( + "Error processing document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) finally: tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") @@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue( # Use rpop to get the next task from the queue (FIFO order) next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks) + logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) if next_tasks: for next_task in next_tasks: diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6492e356a3..4078c8910e 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -1,13 +1,16 @@ import logging import time +from collections.abc import Callable, Sequence import click from celery import shared_task from sqlalchemy import select from configs import dify_config +from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): :param dataset_id: :param document_ids: + .. warning:: TO BE DEPRECATED + This function will be deprecated and removed in a future version. + Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead. + Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids) """ + logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids) + _duplicate_document_indexing_task(dataset_id, document_ids) + + +def _duplicate_document_indexing_task_with_tenant_queue( + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] +): + try: + _duplicate_document_indexing_task(dataset_id, document_ids) + except Exception: + logger.exception( + "Error processing duplicate document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) + finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + + logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) + + if next_tasks: + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=document_task.tenant_id, + dataset_id=document_task.dataset_id, + document_ids=document_task.document_ids, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() + + +def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]): documents = [] start_at = time.perf_counter() @@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) finally: db.session.close() + + +@shared_task(queue="dataset") +def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +@shared_task(queue="priority_dataset") +def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index a7f61d9811..1eef361a92 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 92f1dfb73d..275f5abe6e 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..aca4be1ffd --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,763 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, # Core function + _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function + duplicate_document_indexing_task, # Deprecated old interface + normal_duplicate_document_indexing_task, # New normal task + priority_duplicate_document_indexing_task, # New priority task +) + + +class TestDuplicateDocumentIndexingTasks: + """Integration tests for duplicate document indexing tasks using testcontainers. + + This test class covers: + - Core _duplicate_document_indexing_task function + - Deprecated duplicate_document_indexing_task function + - New normal_duplicate_document_indexing_task function + - New priority_duplicate_document_indexing_task function + - Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function + - Document segment cleanup logic + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner, + patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service, + patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory, + ): + # Setup mock indexing runner + mock_runner_instance = MagicMock() + mock_indexing_runner.return_value = mock_runner_instance + + # Setup mock feature service + mock_features = MagicMock() + mock_features.billing.enabled = False + mock_feature_service.get_features.return_value = mock_features + + # Setup mock index processor factory + mock_processor = MagicMock() + mock_processor.clean = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "indexing_runner": mock_indexing_runner, + "indexing_runner_instance": mock_runner_instance, + "feature_service": mock_feature_service, + "features": mock_features, + "index_processor_factory": mock_index_processor_factory, + "index_processor": mock_processor, + } + + def _create_test_dataset_and_documents( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + ): + """ + Helper method to create a test dataset and documents for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(document_count): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def _create_test_dataset_with_segments( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + ): + """ + Helper method to create a test dataset with documents and segments. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + segments_per_doc: Number of segments per document + + Returns: + tuple: (dataset, documents, segments) - Created dataset, documents and segments + """ + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count + ) + + fake = Faker() + segments = [] + + # Create segments for each document + for document in documents: + for i in range(segments_per_doc): + segment = DocumentSegment( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + index_node_id=f"{document.id}-node-{i}", + index_node_hash=fake.sha256(), + content=fake.text(max_nb_chars=200), + word_count=50, + tokens=100, + status="completed", + enabled=True, + indexing_at=fake.date_time_this_year(), + created_by=dataset.created_by, # Add required field + ) + db.session.add(segment) + segments.append(segment) + + db.session.commit() + + # Refresh to ensure all relationships are loaded + for document in documents: + db.session.refresh(document) + + return dataset, documents, segments + + def _create_test_dataset_with_billing_features( + self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ): + """ + Helper method to create a test dataset with billing features configured. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + billing_enabled: Whether billing is enabled + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(3): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Configure billing features + mock_external_service_dependencies["features"].billing.enabled = billing_enabled + if billing_enabled: + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 50 + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def test_duplicate_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful duplicate document indexing with multiple documents. + + This test verifies: + - Proper dataset retrieval from database + - Correct document processing and status updates + - IndexingRunner integration + - Database state updates + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify the expected outcomes + # Verify indexing runner was called correctly + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 3 + + def test_duplicate_document_indexing_task_with_segment_cleanup( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test duplicate document indexing with existing segments that need cleanup. + + This test verifies: + - Old segments are identified and cleaned + - Index processor clean method is called + - Segments are deleted from database + - New indexing proceeds after cleanup + """ + # Arrange: Create test data with existing segments + dataset, documents, segments = self._create_test_dataset_with_segments( + db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify segment cleanup + # Verify index processor clean was called for each document with segments + assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents) + + # Verify segments were deleted from database + # Re-query segments from database since _duplicate_document_indexing_task uses a different session + for segment in segments: + deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + assert deleted_segment is None + + # Verify documents were updated to parsing status + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify indexing runner was called + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_duplicate_document_indexing_task_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. + + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary indexing runner calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + document_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent dataset + _duplicate_document_indexing_task(non_existent_dataset_id, document_ids) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["indexing_runner"].assert_not_called() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + mock_external_service_dependencies["index_processor"].clean.assert_not_called() + + def test_duplicate_document_indexing_task_document_not_found_in_dataset( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when some documents don't exist in the dataset. + + This test verifies: + - Only existing documents are processed + - Non-existent documents are ignored + - Indexing runner receives only valid documents + - Database state updates correctly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Mix existing and non-existent document IDs + fake = Faker() + existing_document_ids = [doc.id for doc in documents] + non_existent_document_ids = [fake.uuid4() for _ in range(2)] + all_document_ids = existing_document_ids + non_existent_document_ids + + # Act: Execute the task with mixed document IDs + _duplicate_document_indexing_task(dataset.id, all_document_ids) + + # Assert: Verify only existing documents were processed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only existing documents were updated + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in existing_document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with only existing documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 2 # Only existing documents + + def test_duplicate_document_indexing_task_indexing_runner_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of IndexingRunner exceptions. + + This test verifies: + - Exceptions from IndexingRunner are properly caught + - Task completes without raising exceptions + - Database session is properly closed + - Error logging occurs + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception( + "Indexing runner failed" + ) + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + # Re-query documents from database since _duplicate_document_indexing_task close the session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for sandbox plan batch upload limit. + + This test verifies: + - Sandbox plan batch upload limit enforcement + - Error handling for batch upload limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure sandbox plan with batch limit + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + + # Create more documents than sandbox plan allows (limit is 1) + fake = Faker() + extra_documents = [] + for i in range(2): # Total will be 5 documents (3 existing + 2 new) + document = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=i + 3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + extra_documents.append(document) + + db.session.commit() + all_documents = documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with too many documents for sandbox plan + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "batch upload" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for vector space limit. + + This test verifies: + - Vector space limit enforcement + - Error handling for vector space limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure TEAM plan with vector space limit exceeded + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 98 # Almost at limit + + document_ids = [doc.id for doc in documents] # 3 documents will exceed limit + + # Act: Execute the task with documents that will exceed vector space limit + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "limit" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_with_empty_document_list( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of empty document list. + + This test verifies: + - Empty document list is handled gracefully + - No processing occurs + - No errors are raised + - Database session is properly closed + """ + # Arrange: Create test dataset + dataset, _ = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=0 + ) + document_ids = [] + + # Act: Execute the task with empty document list + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify IndexingRunner was called with empty list + # Note: The actual implementation does call run([]) with empty list + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) + + def test_deprecated_duplicate_document_indexing_task_delegates_to_core( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that deprecated duplicate_document_indexing_task delegates to core function. + + This test verifies: + - Deprecated function calls core _duplicate_document_indexing_task + - Proper parameter passing + - Backward compatibility + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the deprecated task + duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify core function was executed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_normal_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test normal_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management (pull tasks, delete key) works properly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the normal task + normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_priority_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test priority_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management works properly + - Same behavior as normal task with different queue assignment + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the priority task + priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_tenant_queue_wrapper_processes_next_tasks( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant queue wrapper processes next queued tasks. + + This test verifies: + - After completing current task, next tasks are pulled from queue + - Next tasks are executed correctly + - Task waiting time is set for next tasks + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Extract values before session detachment + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock tenant isolated queue to return next task + mock_queue = MagicMock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_queue.pull_tasks.return_value = [next_task] + mock_queue_class.return_value = mock_queue + + # Mock the task function to track calls + mock_task_func = MagicMock() + + # Act: Execute the wrapper function + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert: Verify next task was scheduled + mock_queue.pull_tasks.assert_called_once() + mock_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + mock_queue.delete_task_key.assert_not_called() diff --git a/api/tests/unit_tests/services/document_indexing_task_proxy.py b/api/tests/unit_tests/services/document_indexing_task_proxy.py index 765c4b5e32..ff243b8dc3 100644 --- a/api/tests/unit_tests/services/document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/document_indexing_task_proxy.py @@ -117,7 +117,7 @@ import pytest from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy # ============================================================================ # Test Data Factory @@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy: # Features Property Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property(self, mock_feature_service): """ Test cached_property features. @@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy: mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property_with_different_tenants(self, mock_feature_service): """ Test features property with different tenant IDs. @@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy: # Direct Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """ Test _send_to_direct_queue method. @@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy: # Assert mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_direct_queue_with_priority_task(self, mock_task): """ Test _send_to_direct_queue with priority task function. @@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_single_document(self, mock_task): """ Test _send_to_direct_queue with single document ID. @@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_empty_documents(self, mock_task): """ Test _send_to_direct_queue with empty document_ids list. @@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy: # Tenant Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """ Test _send_to_tenant_queue when task key exists. @@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy: mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """ Test _send_to_tenant_queue when no task key exists. @@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy: proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_tenant_queue_with_priority_task(self, mock_task): """ Test _send_to_tenant_queue with priority task function. @@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_document_task_serialization(self, mock_task): """ Test DocumentTask serialization in _send_to_tenant_queue. @@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy: # Queue Type Selection Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_default_tenant_queue(self, mock_task): """ Test _send_to_default_tenant_queue method. @@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_tenant_queue(self, mock_task): """ Test _send_to_priority_tenant_queue method. @@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_direct_queue(self, mock_task): """ Test _send_to_priority_direct_queue method. @@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy: # Dispatch Logic Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with SANDBOX plan. @@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with TEAM plan. @@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with PROFESSIONAL plan. @@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """ Test _dispatch method when billing is disabled. @@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """ Test _dispatch method with empty plan string. @@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """ Test _dispatch method with None plan. @@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy: # Delay Method Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method(self, mock_feature_service): """ Test delay method integration. @@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_team_plan(self, mock_feature_service): """ Test delay method with TEAM plan. @@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_billing_disabled(self, mock_feature_service): """ Test delay method with billing disabled. @@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy: # Batch Operations Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_multiple_documents(self, mock_task): """ Test batch operation with multiple documents. @@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_large_batch(self, mock_task): """ Test batch operation with large batch of documents. @@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy: # Error Handling Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_task_delay_failure(self, mock_task): """ Test _send_to_direct_queue when task.delay() raises an exception. @@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Task delay failed"): proxy._send_to_direct_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_push_tasks_failure(self, mock_task): """ Test _send_to_tenant_queue when push_tasks raises an exception. @@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Push tasks failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task): """ Test _send_to_tenant_queue when set_task_waiting_time raises an exception. @@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Set waiting time failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_feature_service_failure(self, mock_feature_service): """ Test _dispatch when FeatureService.get_features raises an exception. @@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy: # Integration Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service): """ Test full flow for SANDBOX plan with tenant queue. @@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_team_plan(self, mock_task, mock_feature_service): """ Test full flow for TEAM plan with priority tenant queue. @@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_billing_disabled(self, mock_task, mock_feature_service): """ Test full flow for billing disabled (self-hosted/enterprise). diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py index d9183be9fb..98c30c3722 100644 --- a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy class DocumentIndexingTaskProxyTestDataFactory: @@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy: assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_features_property(self, mock_feature_service): """Test cached_property features.""" # Arrange @@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy: assert features1 is features2 # Should be the same instance due to caching mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """Test _send_to_direct_queue method.""" # Arrange @@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """Test _send_to_tenant_queue when task key exists.""" # Arrange @@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy: assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """Test _send_to_tenant_queue when no task key exists.""" # Arrange @@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy: ) proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_default_tenant_queue(self, mock_task): + def test_send_to_default_tenant_queue(self): """Test _send_to_default_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_default_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_tenant_queue(self, mock_task): + def test_send_to_priority_tenant_queue(self): """Test _send_to_priority_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_direct_queue(self, mock_task): + def test_send_to_priority_direct_queue(self): """Test _send_to_priority_direct_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_direct_queue() # Assert - proxy._send_to_direct_queue.assert_called_once_with(mock_task) + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with sandbox plan.""" # Arrange @@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with non-sandbox plan.""" # Arrange @@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy: # If billing enabled with non sandbox plan, should send to priority tenant queue proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """Test _dispatch method when billing is disabled.""" # Arrange @@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy: # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_delay_method(self, mock_feature_service): """Test delay method integration.""" # Arrange @@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy: assert task.dataset_id == dataset_id assert task.document_ids == document_ids - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """Test _dispatch method with empty plan string.""" # Arrange @@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """Test _dispatch method with None plan.""" # Arrange diff --git a/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..68bafe3d5e --- /dev/null +++ b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py @@ -0,0 +1,363 @@ +from unittest.mock import Mock, patch + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import ( + DuplicateDocumentIndexingTaskProxy, +) + + +class DuplicateDocumentIndexingTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_duplicate_document_task_proxy( + tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None + ) -> DuplicateDocumentIndexingTaskProxy: + """Create DuplicateDocumentIndexingTaskProxy instance for testing.""" + if document_ids is None: + document_ids = ["doc-1", "doc-2", "doc-3"] + return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + +class TestDuplicateDocumentIndexingTaskProxy: + """Test cases for DuplicateDocumentIndexingTaskProxy class.""" + + def test_initialization(self): + """Test DuplicateDocumentIndexingTaskProxy initialization.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2", "doc-3"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing" + + def test_queue_name(self): + """Test QUEUE_NAME class variable.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.QUEUE_NAME == "duplicate_document_indexing" + + def test_task_functions(self): + """Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task" + assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task" + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(mock_task) + + # Assert + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() + pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] + assert len(pushed_tasks) == 1 + assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask) + assert pushed_tasks[0]["tenant_id"] == "tenant-123" + assert pushed_tasks[0]["dataset_id"] == "dataset-456" + assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] + mock_task.delay.assert_not_called() + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + def test_send_to_default_tenant_queue(self): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_default_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) + + def test_send_to_priority_tenant_queue(self): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_priority_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + def test_send_to_priority_direct_queue(self): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_direct_queue = Mock() + + # Act + proxy._send_to_priority_direct_queue() + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing enabled with non sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_disabled(self, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_delay_method(self, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy.delay() + + # Assert + # If billing enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_empty_plan(self, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan="" + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_none_plan(self, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=None + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + def test_initialization_with_empty_document_ids(self): + """Test initialization with empty document_ids list.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_single_document_id(self): + """Test initialization with single document_id.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_large_batch(self): + """Test initialization with large batch of document IDs.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert len(proxy._document_ids) == 100 + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_professional_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with professional plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.PROFESSIONAL + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index b3b29fbe45..9d7599b8fe 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.dataset import Dataset, Document -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -138,7 +138,9 @@ class TestTaskEnqueuing: with patch.object(DocumentIndexingTaskProxy, "features") as mock_features: mock_features.billing.enabled = False - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -163,7 +165,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.SANDBOX - with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -187,7 +191,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -211,7 +217,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -1493,7 +1501,9 @@ class TestEdgeCases: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): # Act - Enqueue multiple tasks rapidly for doc_ids in document_ids_list: proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids) @@ -1898,7 +1908,7 @@ class TestRobustness: - Error is propagated appropriately """ # Arrange - with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features: + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features: # Simulate FeatureService failure mock_get_features.side_effect = Exception("Feature service unavailable") diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..0be6ea045e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,567 @@ +""" +Unit tests for duplicate document indexing tasks. + +This module tests the duplicate document indexing task functionality including: +- Task enqueuing to different queues (normal, priority, tenant-isolated) +- Batch processing of multiple duplicate documents +- Progress tracking through task lifecycle +- Error handling and retry mechanisms +- Cleanup of old document data before re-indexing +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, + _duplicate_document_indexing_task_with_tenant_queue, + duplicate_document_indexing_task, + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_ids(): + """Generate a list of document IDs for testing.""" + return [str(uuid.uuid4()) for _ in range(3)] + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_documents(document_ids, dataset_id): + """Create mock Document objects.""" + documents = [] + for doc_id in document_ids: + doc = Mock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + doc.doc_form = "text_model" + documents.append(doc) + return documents + + +@pytest.fixture +def mock_document_segments(document_ids): + """Create mock DocumentSegment objects.""" + segments = [] + for doc_id in document_ids: + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = doc_id + segment.index_node_id = f"node-{doc_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner_class.return_value = mock_runner + yield mock_runner + + +@pytest.fixture +def mock_feature_service(): + """Mock FeatureService.""" + with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service: + mock_features = Mock() + mock_features.billing = Mock() + mock_features.billing.enabled = False + mock_features.vector_space = Mock() + mock_features.vector_space.size = 0 + mock_features.vector_space.limit = 1000 + mock_service.get_features.return_value = mock_features + yield mock_service + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_tenant_isolated_queue(): + """Mock TenantIsolatedTaskQueue.""" + with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class: + mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) + mock_queue.pull_tasks.return_value = [] + mock_queue.delete_task_key = Mock() + mock_queue.set_task_waiting_time = Mock() + mock_queue_class.return_value = mock_queue + yield mock_queue + + +# ============================================================================ +# Tests for deprecated duplicate_document_indexing_task +# ============================================================================ + + +class TestDuplicateDocumentIndexingTask: + """Tests for the deprecated duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids): + """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function.""" + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id): + """Test duplicate_document_indexing_task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + +# ============================================================================ +# Tests for _duplicate_document_indexing_task core function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskCore: + """Tests for the _duplicate_document_indexing_task core function.""" + + def test_successful_duplicate_document_indexing( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test successful duplicate document indexing flow.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify IndexingRunner was called + mock_indexing_runner.run.assert_called_once() + + # Verify all documents were set to parsing status + for doc in mock_documents: + assert doc.indexing_status == "parsing" + assert doc.processing_started_at is not None + + # Verify session operations + assert mock_db_session.commit.called + assert mock_db_session.close.called + + def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids): + """Test duplicate document indexing when dataset is not found.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session at least once + assert mock_db_session.close.called + + def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + dataset_id, + document_ids, + ): + """Test duplicate document indexing with billing enabled and sandbox plan.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.SANDBOX + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # For sandbox plan with multiple documents, should fail + mock_db_session.commit.assert_called() + + def test_duplicate_document_indexing_with_billing_limit_exceeded( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when billing limit is exceeded.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.TEAM + mock_features.vector_space.size = 990 + mock_features.vector_space.limit = 1000 + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should commit the session + assert mock_db_session.commit.called + # Should close the session + assert mock_db_session.close.called + + def test_duplicate_document_indexing_runner_error( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when IndexingRunner raises an error.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session even after error + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_document_is_paused( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when document is paused.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should handle DocumentIsPausedError gracefully + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_cleans_old_segments( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test that duplicate document indexing cleans old segments.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify clean was called for each document + assert mock_processor.clean.call_count == len(mock_documents) + + # Verify segments were deleted + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + +# ============================================================================ +# Tests for tenant queue wrapper function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskWithTenantQueue: + """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_calls_core_function( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper calls the core function.""" + # Arrange + mock_task_func = Mock() + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_deletes_key_when_no_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper deletes task key when no more tasks.""" + # Arrange + mock_task_func = Mock() + mock_tenant_isolated_queue.pull_tasks.return_value = [] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.delete_task_key.assert_called_once() + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_processes_next_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper processes next tasks from queue.""" + # Arrange + mock_task_func = Mock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_tenant_isolated_queue.pull_tasks.return_value = [next_task] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_handles_core_function_error( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper handles errors from core function.""" + # Arrange + mock_task_func = Mock() + mock_core_func.side_effect = Exception("Core function error") + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + # Should still check for next tasks even after error + mock_tenant_isolated_queue.pull_tasks.assert_called_once() + + +# ============================================================================ +# Tests for normal_duplicate_document_indexing_task +# ============================================================================ + + +class TestNormalDuplicateDocumentIndexingTask: + """Tests for normal_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that normal task calls tenant queue wrapper.""" + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_with_empty_document_ids( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test normal task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +# ============================================================================ +# Tests for priority_duplicate_document_indexing_task +# ============================================================================ + + +class TestPriorityDuplicateDocumentIndexingTask: + """Tests for priority_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that priority task calls tenant queue wrapper.""" + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_single_document( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with single document.""" + # Arrange + document_ids = ["doc-1"] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_large_batch( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with large batch of documents.""" + # Arrange + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + )