From 37903722fe3583cd82addb804e57b1dcdd959e4b Mon Sep 17 00:00:00 2001 From: hj24 Date: Thu, 6 Nov 2025 21:25:50 +0800 Subject: [PATCH] refactor: implement tenant self queue for rag tasks (#27559) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- --- api/.env.example | 3 + api/configs/feature/__init__.py | 8 + .../app/apps/pipeline/pipeline_generator.py | 36 +- api/core/entities/document_task.py | 15 + api/core/rag/pipeline/__init__.py | 0 api/core/rag/pipeline/queue.py | 79 ++ api/docker/entrypoint.sh | 2 +- api/services/dataset_service.py | 4 +- api/services/document_indexing_task_proxy.py | 83 ++ .../rag_pipeline/rag_pipeline_task_proxy.py | 106 ++ api/tasks/document_indexing_task.py | 79 ++ .../priority_rag_pipeline_run_task.py | 25 + .../rag_pipeline/rag_pipeline_run_task.py | 34 +- .../core/rag/__init__.py | 1 + .../core/rag/pipeline/__init__.py | 0 .../rag/pipeline/test_queue_integration.py | 595 +++++++++++ .../tasks/test_document_indexing_task.py | 414 +++++++- .../tasks/test_rag_pipeline_run_tasks.py | 936 ++++++++++++++++++ .../core/rag/pipeline/test_queue.py | 301 ++++++ .../test_document_indexing_task_proxy.py | 317 ++++++ .../services/test_rag_pipeline_task_proxy.py | 483 +++++++++ dev/start-worker | 2 +- docker/.env.example | 3 + docker/docker-compose.yaml | 1 + 24 files changed, 3433 insertions(+), 94 deletions(-) create mode 100644 api/core/entities/document_task.py create mode 100644 api/core/rag/pipeline/__init__.py create mode 100644 api/core/rag/pipeline/queue.py create mode 100644 api/services/document_indexing_task_proxy.py create mode 100644 api/services/rag_pipeline/rag_pipeline_task_proxy.py create mode 100644 api/tests/test_containers_integration_tests/core/rag/__init__.py create mode 100644 api/tests/test_containers_integration_tests/core/rag/pipeline/__init__.py create mode 100644 api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py create mode 100644 api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py create mode 100644 api/tests/unit_tests/core/rag/pipeline/test_queue.py create mode 100644 api/tests/unit_tests/services/test_document_indexing_task_proxy.py create mode 100644 api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py diff --git a/api/.env.example b/api/.env.example index f5bfa72254..3120e1cdd6 100644 --- a/api/.env.example +++ b/api/.env.example @@ -615,5 +615,8 @@ SWAGGER_UI_PATH=/swagger-ui.html # Set to false to export dataset IDs as plain text for easier cross-environment import DSL_EXPORT_ENCRYPT_DATASET_ID=true +# Tenant isolated task queue configuration +TENANT_ISOLATED_TASK_CONCURRENCY=1 + # Maximum number of segments for dataset segments API (0 for unlimited) DATASET_MAX_SEGMENTS_PER_REQUEST=0 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 843e6b6f70..86c37dca25 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1142,6 +1142,13 @@ class SwaggerUIConfig(BaseSettings): ) +class TenantIsolatedTaskQueueConfig(BaseSettings): + TENANT_ISOLATED_TASK_CONCURRENCY: int = Field( + description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant", + default=1, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1166,6 +1173,7 @@ class FeatureConfig( RagEtlConfig, RepositoryConfig, SecurityConfig, + TenantIsolatedTaskQueueConfig, ToolConfig, UpdateConfig, WorkflowConfig, diff --git 
a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index c36d34f571..a1390ad0be 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -40,20 +40,15 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader -from enums.cloud_plan import CloudPlan from extensions.ext_database import db -from extensions.ext_redis import redis_client from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from services.datasource_provider_service import DatasourceProviderService -from services.feature_service import FeatureService -from services.file_service import FileService +from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService -from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task -from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task logger = logging.getLogger(__name__) @@ -249,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator): ) if rag_pipeline_invoke_entities: - # store the rag_pipeline_invoke_entities to object storage - text = [item.model_dump() for item in rag_pipeline_invoke_entities] - name = "rag_pipeline_invoke_entities.json" - # Convert list to proper JSON string - json_text = json.dumps(text) - upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id) - features = FeatureService.get_features(dataset.tenant_id) - if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX: - tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}" - tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}" - - if redis_client.get(tenant_pipeline_task_key): - # Add to waiting queue using List operations (lpush) - redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id) - else: - # Set flag and execute task - redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60) - rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=upload_file.id, - tenant_id=dataset.tenant_id, - ) - - else: - priority_rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=upload_file.id, - tenant_id=dataset.tenant_id, - ) - + RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay() # return batch, dataset, documents return { "batch": batch, diff --git a/api/core/entities/document_task.py b/api/core/entities/document_task.py new file mode 100644 index 0000000000..27ab5c84f7 --- /dev/null +++ b/api/core/entities/document_task.py @@ -0,0 +1,15 @@ +from collections.abc import Sequence +from dataclasses import dataclass + + +@dataclass +class DocumentTask: + """Document task entity for document indexing operations. 
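+
+    Instances are serialized with dataclasses.asdict() when pushed onto the tenant
+    isolated queue and rebuilt with DocumentTask(**payload) by the consuming task.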
+ + This class represents a document indexing task that can be queued + and processed by the document indexing system. + """ + + tenant_id: str + dataset_id: str + document_ids: Sequence[str] diff --git a/api/core/rag/pipeline/__init__.py b/api/core/rag/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py new file mode 100644 index 0000000000..3d4d6f588d --- /dev/null +++ b/api/core/rag/pipeline/queue.py @@ -0,0 +1,79 @@ +import json +from collections.abc import Sequence +from typing import Any + +from pydantic import BaseModel, ValidationError + +from extensions.ext_redis import redis_client + +_DEFAULT_TASK_TTL = 60 * 60 # 1 hour + + +class TaskWrapper(BaseModel): + data: Any + + def serialize(self) -> str: + return self.model_dump_json() + + @classmethod + def deserialize(cls, serialized_data: str) -> "TaskWrapper": + return cls.model_validate_json(serialized_data) + + +class TenantIsolatedTaskQueue: + """ + Simple queue for tenant isolated tasks, used for rag related tenant tasks isolation. + It uses Redis list to store tasks, and Redis key to store task waiting flag. + Support tasks that can be serialized by json. + """ + + def __init__(self, tenant_id: str, unique_key: str): + self._tenant_id = tenant_id + self._unique_key = unique_key + self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}" + self._task_key = f"tenant_{unique_key}_task:{tenant_id}" + + def get_task_key(self): + return redis_client.get(self._task_key) + + def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL): + redis_client.setex(self._task_key, ttl, 1) + + def delete_task_key(self): + redis_client.delete(self._task_key) + + def push_tasks(self, tasks: Sequence[Any]): + serialized_tasks = [] + for task in tasks: + # Store str list directly, maintaining full compatibility for pipeline scenarios + if isinstance(task, str): + serialized_tasks.append(task) + else: + # Use TaskWrapper to do JSON serialization for non-string tasks + wrapper = TaskWrapper(data=task) + serialized_data = wrapper.serialize() + serialized_tasks.append(serialized_data) + + redis_client.lpush(self._queue, *serialized_tasks) + + def pull_tasks(self, count: int = 1) -> Sequence[Any]: + if count <= 0: + return [] + + tasks = [] + for _ in range(count): + serialized_task = redis_client.rpop(self._queue) + if not serialized_task: + break + + if isinstance(serialized_task, bytes): + serialized_task = serialized_task.decode("utf-8") + + try: + wrapper = TaskWrapper.deserialize(serialized_task) + tasks.append(wrapper.data) + except (json.JSONDecodeError, ValidationError, TypeError, ValueError): + # Fall back to raw string for legacy format or invalid JSON + tasks.append(serialized_task) + + return tasks diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 8f6998119e..41b5eb20b5 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \ + -Q ${CELERY_QUEUES:-dataset,priority_dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \ --prefetch-multiplier=1 elif [[ "${MODE}" == "beat" ]]; then diff --git 
a/api/services/dataset_service.py b/api/services/dataset_service.py index 2e255c0a9b..78de76df7e 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -50,6 +50,7 @@ from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding from models.workflow import Workflow +from services.document_indexing_task_proxy import DocumentIndexingTaskProxy from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -79,7 +80,6 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segments_from_index_task import disable_segments_from_index_task -from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task @@ -1694,7 +1694,7 @@ class DocumentService: # trigger async task if document_ids: - document_indexing_task.delay(dataset.id, document_ids) + DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() if duplicate_document_ids: duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) diff --git a/api/services/document_indexing_task_proxy.py b/api/services/document_indexing_task_proxy.py new file mode 100644 index 0000000000..861c84b586 --- /dev/null +++ b/api/services/document_indexing_task_proxy.py @@ -0,0 +1,83 @@ +import logging +from collections.abc import Callable, Sequence +from dataclasses import asdict +from functools import cached_property + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.feature_service import FeatureService +from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task + +logger = logging.getLogger(__name__) + + +class DocumentIndexingTaskProxy: + def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + self._tenant_id = tenant_id + self._dataset_id = dataset_id + self._document_ids = document_ids + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + @cached_property + def features(self): + return FeatureService.get_features(self._tenant_id) + + def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): + logger.info("send dataset %s to direct queue", self._dataset_id) + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + + def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): + logger.info("send dataset %s to tenant queue", self._dataset_id) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks( + [ + asdict( + DocumentTask( + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + ) + ] + ) + logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids) + else: + # Set flag and execute task + 
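+            # The task key acts as a per-tenant concurrency gate: it is set with a TTL
+            # (1 hour by default) before dispatching, so the gate is released even if a
+            # worker dies without clearing it, and the finished task either re-arms it
+            # for the next queued item or deletes it once the queue is empty.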
self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids) + + def _send_to_default_tenant_queue(self): + self._send_to_tenant_queue(normal_document_indexing_task) + + def _send_to_priority_tenant_queue(self): + self._send_to_tenant_queue(priority_document_indexing_task) + + def _send_to_priority_direct_queue(self): + self._send_to_direct_queue(priority_document_indexing_task) + + def _dispatch(self): + logger.info( + "dispatch args: %s - %s - %s", + self._tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + # dispatch to different indexing queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan + self._send_to_default_tenant_queue() + else: + # dispatch to priority pipeline queue with tenant self sub queue for other plans + self._send_to_priority_tenant_queue() + else: + # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue() + + def delay(self): + self._dispatch() diff --git a/api/services/rag_pipeline/rag_pipeline_task_proxy.py b/api/services/rag_pipeline/rag_pipeline_task_proxy.py new file mode 100644 index 0000000000..94dd7941da --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline_task_proxy.py @@ -0,0 +1,106 @@ +import json +import logging +from collections.abc import Callable, Sequence +from functools import cached_property + +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from services.feature_service import FeatureService +from services.file_service import FileService +from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task +from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task + +logger = logging.getLogger(__name__) + + +class RagPipelineTaskProxy: + # Default uploaded file name for rag pipeline invoke entities + _RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json" + + def __init__( + self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity] + ): + self._dataset_tenant_id = dataset_tenant_id + self._user_id = user_id + self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline") + + @cached_property + def features(self): + return FeatureService.get_features(self._dataset_tenant_id) + + def _upload_invoke_entities(self) -> str: + text = [item.model_dump() for item in self._rag_pipeline_invoke_entities] + # Convert list to proper JSON string + json_text = json.dumps(text) + upload_file = FileService(db.engine).upload_text( + json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id + ) + return upload_file.id + + def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): + logger.info("send file %s to direct queue", upload_file_id) + task_func.delay( # type: ignore + 
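+            # Only primitives cross the Celery boundary: the invoke entities were
+            # serialized to an uploaded JSON file by _upload_invoke_entities(), and the
+            # worker resolves them again from this file id.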
rag_pipeline_invoke_entities_file_id=upload_file_id, + tenant_id=self._dataset_tenant_id, + ) + + def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): + logger.info("send file %s to tenant queue", upload_file_id) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks([upload_file_id]) + logger.info("push tasks: %s", upload_file_id) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=upload_file_id, + tenant_id=self._dataset_tenant_id, + ) + logger.info("init tasks: %s", upload_file_id) + + def _send_to_default_tenant_queue(self, upload_file_id: str): + self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task) + + def _send_to_priority_tenant_queue(self, upload_file_id: str): + self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task) + + def _send_to_priority_direct_queue(self, upload_file_id: str): + self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task) + + def _dispatch(self): + upload_file_id = self._upload_invoke_entities() + if not upload_file_id: + raise ValueError("upload_file_id is empty") + + logger.info( + "dispatch args: %s - %s - %s", + self._dataset_tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + + # dispatch to different pipeline queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant isolation for sandbox plan + self._send_to_default_tenant_queue(upload_file_id) + else: + # dispatch to priority pipeline queue with tenant isolation for other plans + self._send_to_priority_tenant_queue(upload_file_id) + else: + # dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue(upload_file_id) + + def delay(self): + if not self._rag_pipeline_invoke_entities: + logger.warning( + "Received empty rag pipeline invoke entities, no tasks delivered: %s %s", + self._dataset_tenant_id, + self._user_id, + ) + return + self._dispatch() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 07f469de0e..fee4430612 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -1,11 +1,14 @@ import logging import time +from collections.abc import Callable, Sequence import click from celery import shared_task from configs import dify_config +from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -22,8 +25,24 @@ def document_indexing_task(dataset_id: str, document_ids: list): :param dataset_id: :param document_ids: + .. warning:: TO BE DEPRECATED + This function will be deprecated and removed in a future version. + Use normal_document_indexing_task or priority_document_indexing_task instead. 
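+        New call sites should go through services.document_indexing_task_proxy.DocumentIndexingTaskProxy,
+        which routes to the normal or priority queue based on the tenant's billing plan.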
+ Usage: document_indexing_task.delay(dataset_id, document_ids) """ + logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids) + _document_indexing(dataset_id, document_ids) + + +def _document_indexing(dataset_id: str, document_ids: Sequence[str]): + """ + Process document for tasks + :param dataset_id: + :param document_ids: + + Usage: _document_indexing(dataset_id, document_ids) + """ documents = [] start_at = time.perf_counter() @@ -87,3 +106,63 @@ def document_indexing_task(dataset_id: str, document_ids: list): logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) finally: db.session.close() + + +def _document_indexing_with_tenant_queue( + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] +): + try: + _document_indexing(dataset_id, document_ids) + except Exception: + logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id) + finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + + logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks) + + if next_tasks: + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=document_task.tenant_id, + dataset_id=document_task.dataset_id, + document_ids=document_task.document_ids, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() + + +@shared_task(queue="dataset") +def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process document + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task) + + +@shared_task(queue="priority_dataset") +def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Priority async process document + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 6de95a3b85..a7f61d9811 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -12,8 +12,10 @@ from celery import shared_task # type: ignore from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from 
core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from core.repositories.factory import DifyCoreRepositoryFactory from extensions.ext_database import db from models import Account, Tenant @@ -22,6 +24,8 @@ from models.enums import WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.file_service import FileService +logger = logging.getLogger(__name__) + @shared_task(queue="priority_pipeline") def priority_rag_pipeline_run_task( @@ -69,6 +73,27 @@ def priority_rag_pipeline_run_task( logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) raise finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids) + + if next_file_ids: + for next_file_id in next_file_ids: + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + priority_rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") + if isinstance(next_file_id, bytes) + else next_file_id, + tenant_id=tenant_id, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() file_service = FileService(db.engine) file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index f4a092d97e..92f1dfb73d 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -12,17 +12,20 @@ from celery import shared_task # type: ignore from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from core.repositories.factory import DifyCoreRepositoryFactory from extensions.ext_database import db -from extensions.ext_redis import redis_client from models import Account, Tenant from models.dataset import Pipeline from models.enums import WorkflowRunTriggeredFrom from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.file_service import FileService +logger = logging.getLogger(__name__) + @shared_task(queue="pipeline") def rag_pipeline_run_task( @@ -70,26 +73,27 @@ def rag_pipeline_run_task( logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red")) raise finally: - tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}" - tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}" + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline") # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) - next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue) + next_file_ids = 
tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids) - if next_file_id: - # Process the next waiting task - # Keep the flag set to indicate a task is running - redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1) - rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") - if isinstance(next_file_id, bytes) - else next_file_id, - tenant_id=tenant_id, - ) + if next_file_ids: + for next_file_id in next_file_ids: + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + rag_pipeline_run_task.delay( # type: ignore + rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") + if isinstance(next_file_id, bytes) + else next_file_id, + tenant_id=tenant_id, + ) else: # No more waiting tasks, clear the flag - redis_client.delete(tenant_pipeline_task_key) + tenant_isolated_task_queue.delete_task_key() file_service = FileService(db.engine) file_service.delete_file(rag_pipeline_invoke_entities_file_id) db.session.close() diff --git a/api/tests/test_containers_integration_tests/core/rag/__init__.py b/api/tests/test_containers_integration_tests/core/rag/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/rag/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/__init__.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py new file mode 100644 index 0000000000..cdf390b327 --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -0,0 +1,595 @@ +""" +Integration tests for TenantIsolatedTaskQueue using testcontainers. + +These tests verify the Redis-based task queue functionality with real Redis instances, +testing tenant isolation, task serialization, and queue operations in a realistic environment. +Includes compatibility tests for migrating from legacy string-only queues. + +All tests use generic naming to avoid coupling to specific business implementations. 
+""" + +import time +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +import pytest +from faker import Faker + +from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue +from extensions.ext_redis import redis_client +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole + + +@dataclass +class TestTask: + """Test task data structure for testing complex object serialization.""" + + task_id: str + tenant_id: str + data: dict[str, Any] + metadata: dict[str, Any] + + +class TestTenantIsolatedTaskQueueIntegration: + """Integration tests for TenantIsolatedTaskQueue using testcontainers.""" + + @pytest.fixture + def fake(self): + """Faker instance for generating test data.""" + return Faker() + + @pytest.fixture + def test_tenant_and_account(self, db_session_with_containers, fake): + """Create test tenant and account for testing.""" + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + return tenant, account + + @pytest.fixture + def test_queue(self, test_tenant_and_account): + """Create a generic test queue for testing.""" + tenant, _ = test_tenant_and_account + return TenantIsolatedTaskQueue(tenant.id, "test_queue") + + @pytest.fixture + def secondary_queue(self, test_tenant_and_account): + """Create a secondary test queue for testing isolation.""" + tenant, _ = test_tenant_and_account + return TenantIsolatedTaskQueue(tenant.id, "secondary_queue") + + def test_queue_initialization(self, test_tenant_and_account): + """Test queue initialization with correct key generation.""" + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "test-key") + + assert queue._tenant_id == tenant.id + assert queue._unique_key == "test-key" + assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" + assert queue._task_key == f"tenant_test-key_task:{tenant.id}" + + def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake): + """Test that different tenants have isolated queues.""" + tenant1, _ = test_tenant_and_account + + # Create second tenant + tenant2 = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant2) + db_session_with_containers.commit() + + queue1 = TenantIsolatedTaskQueue(tenant1.id, "same-key") + queue2 = TenantIsolatedTaskQueue(tenant2.id, "same-key") + + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}" + assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}" + + def test_key_isolation(self, test_tenant_and_account): + """Test that different keys have isolated queues.""" + tenant, _ = test_tenant_and_account + queue1 = TenantIsolatedTaskQueue(tenant.id, "key1") + queue2 = TenantIsolatedTaskQueue(tenant.id, "key2") + + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + assert queue1._queue == 
f"tenant_self_key1_task_queue:{tenant.id}" + assert queue2._queue == f"tenant_self_key2_task_queue:{tenant.id}" + + def test_task_key_operations(self, test_queue): + """Test task key operations (get, set, delete).""" + # Initially no task key should exist + assert test_queue.get_task_key() is None + + # Set task waiting time with default TTL + test_queue.set_task_waiting_time() + task_key = test_queue.get_task_key() + # Redis returns bytes, convert to string for comparison + assert task_key in (b"1", "1") + + # Set task waiting time with custom TTL + custom_ttl = 30 + test_queue.set_task_waiting_time(custom_ttl) + task_key = test_queue.get_task_key() + assert task_key in (b"1", "1") + + # Delete task key + test_queue.delete_task_key() + assert test_queue.get_task_key() is None + + def test_push_and_pull_string_tasks(self, test_queue): + """Test pushing and pulling string tasks.""" + tasks = ["task1", "task2", "task3"] + + # Push tasks + test_queue.push_tasks(tasks) + + # Pull tasks (FIFO order) + pulled_tasks = test_queue.pull_tasks(3) + + # Should get tasks in FIFO order (lpush + rpop = FIFO) + assert pulled_tasks == ["task1", "task2", "task3"] + + def test_push_and_pull_multiple_tasks(self, test_queue): + """Test pushing and pulling multiple tasks at once.""" + tasks = ["task1", "task2", "task3", "task4", "task5"] + + # Push tasks + test_queue.push_tasks(tasks) + + # Pull multiple tasks + pulled_tasks = test_queue.pull_tasks(3) + assert len(pulled_tasks) == 3 + assert pulled_tasks == ["task1", "task2", "task3"] + + # Pull remaining tasks + remaining_tasks = test_queue.pull_tasks(5) + assert len(remaining_tasks) == 2 + assert remaining_tasks == ["task4", "task5"] + + def test_push_and_pull_complex_objects(self, test_queue, fake): + """Test pushing and pulling complex object tasks.""" + # Create complex task objects as dictionaries (not dataclass instances) + tasks = [ + { + "task_id": str(uuid4()), + "tenant_id": test_queue._tenant_id, + "data": { + "file_id": str(uuid4()), + "content": fake.text(), + "metadata": {"size": fake.random_int(1000, 10000)}, + }, + "metadata": {"created_at": fake.iso8601(), "tags": fake.words(3)}, + }, + { + "task_id": str(uuid4()), + "tenant_id": test_queue._tenant_id, + "data": { + "file_id": str(uuid4()), + "content": "测试中文内容", + "metadata": {"size": fake.random_int(1000, 10000)}, + }, + "metadata": {"created_at": fake.iso8601(), "tags": ["中文", "测试", "emoji🚀"]}, + }, + ] + + # Push complex tasks + test_queue.push_tasks(tasks) + + # Pull tasks + pulled_tasks = test_queue.pull_tasks(2) + assert len(pulled_tasks) == 2 + + # Verify deserialized tasks match original (FIFO order) + for i, pulled_task in enumerate(pulled_tasks): + original_task = tasks[i] # FIFO order + assert isinstance(pulled_task, dict) + assert pulled_task["task_id"] == original_task["task_id"] + assert pulled_task["tenant_id"] == original_task["tenant_id"] + assert pulled_task["data"] == original_task["data"] + assert pulled_task["metadata"] == original_task["metadata"] + + def test_mixed_task_types(self, test_queue, fake): + """Test pushing and pulling mixed string and object tasks.""" + string_task = "simple_string_task" + object_task = { + "task_id": str(uuid4()), + "dataset_id": str(uuid4()), + "document_ids": [str(uuid4()) for _ in range(3)], + } + + tasks = [string_task, object_task, "another_string"] + + # Push mixed tasks + test_queue.push_tasks(tasks) + + # Pull all tasks + pulled_tasks = test_queue.pull_tasks(3) + assert len(pulled_tasks) == 3 + + # Verify types and content + assert 
pulled_tasks[0] == string_task + assert isinstance(pulled_tasks[1], dict) + assert pulled_tasks[1] == object_task + assert pulled_tasks[2] == "another_string" + + def test_empty_queue_operations(self, test_queue): + """Test operations on empty queue.""" + # Pull from empty queue + tasks = test_queue.pull_tasks(5) + assert tasks == [] + + # Pull zero or negative count + assert test_queue.pull_tasks(0) == [] + assert test_queue.pull_tasks(-1) == [] + + def test_task_ttl_expiration(self, test_queue): + """Test task key TTL expiration.""" + # Set task with short TTL + short_ttl = 2 + test_queue.set_task_waiting_time(short_ttl) + + # Verify task key exists + assert test_queue.get_task_key() == b"1" or test_queue.get_task_key() == "1" + + # Wait for TTL to expire + time.sleep(short_ttl + 1) + + # Verify task key has expired + assert test_queue.get_task_key() is None + + def test_large_task_batch(self, test_queue, fake): + """Test handling large batches of tasks.""" + # Create large batch of tasks + large_batch = [] + for i in range(100): + task = { + "task_id": str(uuid4()), + "index": i, + "data": fake.text(max_nb_chars=100), + "metadata": {"batch_id": str(uuid4())}, + } + large_batch.append(task) + + # Push large batch + test_queue.push_tasks(large_batch) + + # Pull all tasks + pulled_tasks = test_queue.pull_tasks(100) + assert len(pulled_tasks) == 100 + + # Verify all tasks were retrieved correctly (FIFO order) + for i, task in enumerate(pulled_tasks): + assert isinstance(task, dict) + assert task["index"] == i # FIFO order + + def test_queue_operations_isolation(self, test_tenant_and_account, fake): + """Test concurrent operations on different queues.""" + tenant, _ = test_tenant_and_account + + # Create multiple queues for the same tenant + queue1 = TenantIsolatedTaskQueue(tenant.id, "queue1") + queue2 = TenantIsolatedTaskQueue(tenant.id, "queue2") + + # Push tasks to different queues + queue1.push_tasks(["task1_queue1", "task2_queue1"]) + queue2.push_tasks(["task1_queue2", "task2_queue2"]) + + # Verify queues are isolated + tasks1 = queue1.pull_tasks(2) + tasks2 = queue2.pull_tasks(2) + + assert tasks1 == ["task1_queue1", "task2_queue1"] + assert tasks2 == ["task1_queue2", "task2_queue2"] + assert tasks1 != tasks2 + + def test_task_wrapper_serialization_roundtrip(self, test_queue, fake): + """Test TaskWrapper serialization and deserialization roundtrip.""" + # Create complex nested data + complex_data = { + "id": str(uuid4()), + "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5], "unicode": "测试中文", "emoji": "🚀"}}, + "metadata": {"created_at": fake.iso8601(), "tags": ["tag1", "tag2", "tag3"]}, + } + + # Create wrapper and serialize + wrapper = TaskWrapper(data=complex_data) + serialized = wrapper.serialize() + + # Verify serialization + assert isinstance(serialized, str) + assert "测试中文" in serialized + assert "🚀" in serialized + + # Deserialize and verify + deserialized_wrapper = TaskWrapper.deserialize(serialized) + assert deserialized_wrapper.data == complex_data + + def test_error_handling_invalid_json(self, test_queue): + """Test error handling for invalid JSON in wrapped tasks.""" + # Manually create invalid JSON task (not a valid TaskWrapper JSON) + invalid_json_task = "invalid json data" + + # Push invalid task directly to Redis + redis_client.lpush(test_queue._queue, invalid_json_task) + + # Pull task - should fall back to string since it's not valid JSON + task = test_queue.pull_tasks(1) + assert task[0] == invalid_json_task + + def 
test_real_world_batch_processing_scenario(self, test_queue, fake): + """Test realistic batch processing scenario.""" + # Simulate batch processing tasks + batch_tasks = [] + for i in range(3): + task = { + "file_id": str(uuid4()), + "tenant_id": test_queue._tenant_id, + "user_id": str(uuid4()), + "processing_config": { + "model": fake.random_element(["model_a", "model_b", "model_c"]), + "temperature": fake.random.uniform(0.1, 1.0), + "max_tokens": fake.random_int(1000, 4000), + }, + "metadata": { + "source": fake.random_element(["upload", "api", "webhook"]), + "priority": fake.random_element(["low", "normal", "high"]), + }, + } + batch_tasks.append(task) + + # Push tasks + test_queue.push_tasks(batch_tasks) + + # Process tasks in batches + batch_size = 2 + processed_tasks = [] + + while True: + batch = test_queue.pull_tasks(batch_size) + if not batch: + break + + processed_tasks.extend(batch) + + # Verify all tasks were processed + assert len(processed_tasks) == 3 + + # Verify task structure + for task in processed_tasks: + assert isinstance(task, dict) + assert "file_id" in task + assert "tenant_id" in task + assert "processing_config" in task + assert "metadata" in task + assert task["tenant_id"] == test_queue._tenant_id + + +class TestTenantIsolatedTaskQueueCompatibility: + """Compatibility tests for migrating from legacy string-only queues.""" + + @pytest.fixture + def fake(self): + """Faker instance for generating test data.""" + return Faker() + + @pytest.fixture + def test_tenant_and_account(self, db_session_with_containers, fake): + """Create test tenant and account for testing.""" + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + return tenant, account + + def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake): + """ + Test compatibility with legacy queues containing only string data. + + This simulates the scenario where Redis queues already contain string data + from the old architecture, and we need to ensure the new code can read them. 
+ """ + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "legacy_queue") + + # Simulate legacy string data in Redis queue (using old format) + legacy_strings = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] + + # Manually push legacy strings directly to Redis (simulating old system) + for legacy_string in legacy_strings: + redis_client.lpush(queue._queue, legacy_string) + + # Verify new code can read legacy string data + pulled_tasks = queue.pull_tasks(5) + assert len(pulled_tasks) == 5 + + # Verify all tasks are strings (not wrapped) + for task in pulled_tasks: + assert isinstance(task, str) + assert task.startswith("legacy_task_") + + # Verify order (FIFO from Redis list) + expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] + assert pulled_tasks == expected_order + + def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake): + """ + Test complete migration scenario from legacy to new system. + + This simulates the real-world scenario where: + 1. Legacy system has string data in Redis + 2. New system starts processing the same queue + 3. Both legacy and new tasks coexist during migration + 4. New system can handle both formats seamlessly + """ + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "migration_queue") + + # Phase 1: Legacy system has data + legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)] + redis_client.lpush(queue._queue, *legacy_tasks) + + # Phase 2: New system starts processing legacy data + processed_legacy = [] + while True: + tasks = queue.pull_tasks(1) + if not tasks: + break + processed_legacy.extend(tasks) + + # Verify legacy data was processed correctly + assert len(processed_legacy) == 5 + for task in processed_legacy: + assert isinstance(task, str) + assert task.startswith("legacy_resource_") + + # Phase 3: New system adds new tasks (mixed types) + new_string_tasks = ["new_resource_1", "new_resource_2"] + new_object_tasks = [ + { + "resource_id": str(uuid4()), + "tenant_id": tenant.id, + "processing_type": "new_system", + "metadata": {"version": "2.0", "features": ["ai", "ml"]}, + }, + { + "resource_id": str(uuid4()), + "tenant_id": tenant.id, + "processing_type": "new_system", + "metadata": {"version": "2.0", "features": ["ai", "ml"]}, + }, + ] + + # Push new tasks using new system + queue.push_tasks(new_string_tasks) + queue.push_tasks(new_object_tasks) + + # Phase 4: Process all new tasks + processed_new = [] + while True: + tasks = queue.pull_tasks(1) + if not tasks: + break + processed_new.extend(tasks) + + # Verify new tasks were processed correctly + assert len(processed_new) == 4 + + string_tasks = [task for task in processed_new if isinstance(task, str)] + object_tasks = [task for task in processed_new if isinstance(task, dict)] + + assert len(string_tasks) == 2 + assert len(object_tasks) == 2 + + # Verify string tasks + for task in string_tasks: + assert task.startswith("new_resource_") + + # Verify object tasks + for task in object_tasks: + assert isinstance(task, dict) + assert "resource_id" in task + assert "tenant_id" in task + assert task["tenant_id"] == tenant.id + assert task["processing_type"] == "new_system" + + def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake): + """ + Test error recovery when legacy queue contains malformed data. + + This ensures the new system can gracefully handle corrupted or + malformed legacy data without crashing. 
+ """ + tenant, _ = test_tenant_and_account + queue = TenantIsolatedTaskQueue(tenant.id, "error_recovery_queue") + + # Create mix of valid and malformed legacy data + mixed_legacy_data = [ + "valid_legacy_task_1", + "valid_legacy_task_2", + "malformed_data_string", # This should be treated as string + "valid_legacy_task_3", + "invalid_json_not_taskwrapper_format", # This should fall back to string (not valid TaskWrapper JSON) + "valid_legacy_task_4", + ] + + # Manually push mixed data directly to Redis + redis_client.lpush(queue._queue, *mixed_legacy_data) + + # Process all tasks + processed_tasks = [] + while True: + tasks = queue.pull_tasks(1) + if not tasks: + break + processed_tasks.extend(tasks) + + # Verify all tasks were processed (no crashes) + assert len(processed_tasks) == 6 + + # Verify all tasks are strings (malformed data falls back to string) + for task in processed_tasks: + assert isinstance(task, str) + + # Verify valid tasks are preserved + valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")] + assert len(valid_tasks) == 4 + + # Verify malformed data is handled gracefully + malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")] + assert len(malformed_tasks) == 2 + assert "malformed_data_string" in malformed_tasks + assert "invalid_json_not_taskwrapper_format" in malformed_tasks diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 1329bba082..c015d7ec9c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -1,17 +1,33 @@ +from dataclasses import asdict from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document -from tasks.document_indexing_task import document_indexing_task +from tasks.document_indexing_task import ( + _document_indexing, # Core function + _document_indexing_with_tenant_queue, # Tenant queue wrapper function + document_indexing_task, # Deprecated old interface + normal_document_indexing_task, # New normal task + priority_document_indexing_task, # New priority task +) -class TestDocumentIndexingTask: - """Integration tests for document_indexing_task using testcontainers.""" +class TestDocumentIndexingTasks: + """Integration tests for document indexing tasks using testcontainers. 
+ + This test class covers: + - Core _document_indexing function + - Deprecated document_indexing_task function + - New normal_document_indexing_task function + - New priority_document_indexing_task function + - Tenant queue wrapper _document_indexing_with_tenant_queue function + """ @pytest.fixture def mock_external_service_dependencies(self): @@ -224,7 +240,7 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in documents] # Act: Execute the task - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify the expected outcomes # Verify indexing runner was called correctly @@ -232,10 +248,11 @@ class TestDocumentIndexingTask: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were updated to parsing status - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args @@ -261,7 +278,7 @@ class TestDocumentIndexingTask: document_ids = [fake.uuid4() for _ in range(3)] # Act: Execute the task with non-existent dataset - document_indexing_task(non_existent_dataset_id, document_ids) + _document_indexing(non_existent_dataset_id, document_ids) # Assert: Verify no processing occurred mock_external_service_dependencies["indexing_runner"].assert_not_called() @@ -291,17 +308,18 @@ class TestDocumentIndexingTask: all_document_ids = existing_document_ids + non_existent_document_ids # Act: Execute the task with mixed document IDs - document_indexing_task(dataset.id, all_document_ids) + _document_indexing(dataset.id, all_document_ids) # Assert: Verify only existing documents were processed mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify only existing documents were updated - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in existing_document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args @@ -333,7 +351,7 @@ class TestDocumentIndexingTask: ) # Act: Execute the task - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions @@ -341,10 +359,11 @@ class TestDocumentIndexingTask: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were still updated to parsing status before the exception - for document in 
documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing close the session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( self, db_session_with_containers, mock_external_service_dependencies @@ -407,17 +426,18 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in all_documents] # Act: Execute the task with mixed document states - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify processing mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify all documents were updated to parsing status - for document in all_documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None # Verify the run method was called with all documents call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args @@ -470,15 +490,16 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in all_documents] # Act: Execute the task with too many documents for sandbox plan - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify error handling - for document in all_documents: - db.session.refresh(document) - assert document.indexing_status == "error" - assert document.error is not None - assert "batch upload" in document.error - assert document.stopped_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "batch upload" in updated_document.error + assert updated_document.stopped_at is not None # Verify no indexing runner was called mock_external_service_dependencies["indexing_runner"].assert_not_called() @@ -503,17 +524,18 @@ class TestDocumentIndexingTask: document_ids = [doc.id for doc in documents] # Act: Execute the task with billing disabled - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify successful processing mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were updated to parsing status - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = 
db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( self, db_session_with_containers, mock_external_service_dependencies @@ -541,7 +563,7 @@ class TestDocumentIndexingTask: ) # Act: Execute the task - document_indexing_task(dataset.id, document_ids) + _document_indexing(dataset.id, document_ids) # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions @@ -549,7 +571,317 @@ class TestDocumentIndexingTask: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify documents were still updated to parsing status before the exception - for document in documents: - db.session.refresh(document) - assert document.indexing_status == "parsing" - assert document.processing_started_at is not None + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== + def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test document_indexing_task basic functionality. + + This test verifies: + - Task function calls the wrapper correctly + - Basic parameter passing works + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the deprecated task (it only takes 2 parameters) + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify processing occurred (core logic is tested in _document_indexing tests) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_normal_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test normal_document_indexing_task basic functionality. + + This test verifies: + - Task function calls the wrapper correctly + - Basic parameter passing works + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + + # Act: Execute the new normal task + normal_document_indexing_task(tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred (core logic is tested in _document_indexing tests) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_priority_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test priority_document_indexing_task basic functionality. 
+ + This test verifies: + - Task function calls the wrapper correctly + - Basic parameter passing works + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + + # Act: Execute the new priority task + priority_document_indexing_task(tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred (core logic is tested in _document_indexing tests) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_document_indexing_with_tenant_queue_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test _document_indexing_with_tenant_queue function with no waiting tasks. + + This test verifies: + - Core indexing logic execution (same as _document_indexing) + - Tenant queue cleanup when no waiting tasks + - Task function parameter passing + - Queue management after processing + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Act: Execute the wrapper function + _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + + # Assert: Verify core processing occurred (same as _document_indexing) + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated (same as _document_indexing) + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] + assert len(processed_documents) == 2 + + # Verify task function was not called (no waiting tasks) + mock_task_func.delay.assert_not_called() + + def test_document_indexing_with_tenant_queue_with_waiting_tasks( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis. 
+
+        This test verifies:
+        - Core indexing logic execution
+        - Real Redis-based tenant queue processing of waiting tasks
+        - Task function calls for waiting tasks
+        - Queue management with multiple tasks using actual Redis operations
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=1
+        )
+        document_ids = [doc.id for doc in documents]
+        tenant_id = dataset.tenant_id
+        dataset_id = dataset.id
+
+        # Mock the task function
+        from unittest.mock import MagicMock
+
+        mock_task_func = MagicMock()
+
+        # Use real Redis for TenantIsolatedTaskQueue
+        from core.rag.pipeline.queue import TenantIsolatedTaskQueue
+
+        # Create real queue instance
+        queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
+
+        # Add waiting tasks to the real Redis queue
+        waiting_tasks = [
+            DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
+            DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"]),
+        ]
+        # Convert DocumentTask objects to dictionaries for serialization
+        waiting_task_dicts = [asdict(task) for task in waiting_tasks]
+        queue.push_tasks(waiting_task_dicts)
+
+        # Act: Execute the wrapper function
+        _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
+
+        # Assert: Verify core processing occurred
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Verify the task function was dispatched exactly once (only one waiting task is pulled per run)
+        assert mock_task_func.delay.call_count == 1
+
+        # Verify correct parameters for the dispatched call
+        calls = mock_task_func.delay.call_args_list
+        assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
+
+        # Verify one task remains in the queue (only one of the two waiting tasks was pulled)
+        remaining_tasks = queue.pull_tasks(count=10)  # Pull more than we added
+        assert len(remaining_tasks) == 1
+
+    def test_document_indexing_with_tenant_queue_error_handling(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test error handling in _document_indexing_with_tenant_queue using real Redis.
+ + This test verifies: + - Exception handling during core processing + - Tenant queue cleanup even on errors using real Redis + - Proper error logging + - Function completes without raising exceptions + - Queue management continues despite core processing errors + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + document_ids = [doc.id for doc in documents] + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error") + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Use real Redis for TenantIsolatedTaskQueue + from core.rag.pipeline.queue import TenantIsolatedTaskQueue + + # Create real queue instance + queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") + + # Add waiting task to the real Redis queue + waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]) + queue.push_tasks([asdict(waiting_task)]) + + # Act: Execute the wrapper function + _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + + # Assert: Verify error was handled gracefully + # The function should not raise exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + # Re-query documents from database since _document_indexing uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify waiting task was still processed despite core processing error + mock_task_func.delay.assert_called_once() + + # Verify correct parameters for the call + call = mock_task_func.delay.call_args + assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_document_indexing_with_tenant_queue_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant isolation in _document_indexing_with_tenant_queue using real Redis. 
+ + This test verifies: + - Different tenants have isolated queues + - Tasks from one tenant don't affect another tenant's queue + - Queue operations are properly scoped to tenant + """ + # Arrange: Create test data for two different tenants + dataset1, documents1 = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + dataset2, documents2 = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=1 + ) + + tenant1_id = dataset1.tenant_id + tenant2_id = dataset2.tenant_id + dataset1_id = dataset1.id + dataset2_id = dataset2.id + document_ids1 = [doc.id for doc in documents1] + document_ids2 = [doc.id for doc in documents2] + + # Mock the task function + from unittest.mock import MagicMock + + mock_task_func = MagicMock() + + # Use real Redis for TenantIsolatedTaskQueue + from core.rag.pipeline.queue import TenantIsolatedTaskQueue + + # Create queue instances for both tenants + queue1 = TenantIsolatedTaskQueue(tenant1_id, "document_indexing") + queue2 = TenantIsolatedTaskQueue(tenant2_id, "document_indexing") + + # Add waiting tasks to both queues + waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"]) + waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"]) + + queue1.push_tasks([asdict(waiting_task1)]) + queue2.push_tasks([asdict(waiting_task2)]) + + # Act: Execute the wrapper function for tenant1 only + _document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func) + + # Assert: Verify core processing occurred for tenant1 + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only tenant1's waiting task was processed + mock_task_func.delay.assert_called_once() + call = mock_task_func.delay.call_args + assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]} + + # Verify tenant1's queue is empty + remaining_tasks1 = queue1.pull_tasks(count=10) + assert len(remaining_tasks1) == 0 + + # Verify tenant2's queue still has its task (isolation) + remaining_tasks2 = queue2.pull_tasks(count=10) + assert len(remaining_tasks2) == 1 + + # Verify queue keys are different + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py new file mode 100644 index 0000000000..c82162238c --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -0,0 +1,936 @@ +import json +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Pipeline +from models.workflow import Workflow +from tasks.rag_pipeline.priority_rag_pipeline_run_task import ( + priority_rag_pipeline_run_task, + run_single_rag_pipeline_task, +) +from 
tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task + + +class TestRagPipelineRunTasks: + """Integration tests for RAG pipeline run tasks using testcontainers. + + This test class covers: + - priority_rag_pipeline_run_task function + - rag_pipeline_run_task function + - run_single_rag_pipeline_task function + - Real Redis-based TenantIsolatedTaskQueue operations + - PipelineGenerator._generate method mocking and parameter validation + - File operations and cleanup + - Error handling and queue management + """ + + @pytest.fixture + def mock_pipeline_generator(self): + """Mock PipelineGenerator._generate method.""" + with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate: + # Mock the _generate method to return a simple response + mock_generate.return_value = {"answer": "Test response", "metadata": {"test": "data"}} + yield mock_generate + + @pytest.fixture + def mock_file_service(self): + """Mock FileService for file operations.""" + with ( + patch("services.file_service.FileService.get_file_content") as mock_get_content, + patch("services.file_service.FileService.delete_file") as mock_delete_file, + ): + yield { + "get_content": mock_get_content, + "delete_file": mock_delete_file, + } + + def _create_test_pipeline_and_workflow(self, db_session_with_containers): + """ + Helper method to create test pipeline and workflow for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant, pipeline, workflow) - Created entities + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create workflow + workflow = Workflow( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + app_id=str(uuid.uuid4()), + type="workflow", + version="draft", + graph="{}", + features="{}", + marked_name=fake.company(), + marked_comment=fake.text(max_nb_chars=100), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db.session.add(workflow) + db.session.commit() + + # Create pipeline + pipeline = Pipeline( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + workflow_id=workflow.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + created_by=account.id, + ) + db.session.add(pipeline) + db.session.commit() + + # Refresh entities to ensure they're properly loaded + db.session.refresh(account) + db.session.refresh(tenant) + db.session.refresh(workflow) + db.session.refresh(pipeline) + + return account, tenant, pipeline, workflow + + def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2): + """ + Helper method to create RAG pipeline invoke entities for testing. 
+ + Args: + account: Account instance + tenant: Tenant instance + pipeline: Pipeline instance + workflow: Workflow instance + count: Number of entities to create + + Returns: + list: List of RagPipelineInvokeEntity instances + """ + fake = Faker() + entities = [] + + for i in range(count): + # Create application generate entity + app_config = { + "app_id": str(uuid.uuid4()), + "app_name": fake.company(), + "mode": "workflow", + "workflow_id": workflow.id, + "tenant_id": tenant.id, + "app_mode": "workflow", + } + + application_generate_entity = { + "task_id": str(uuid.uuid4()), + "app_config": app_config, + "inputs": {"query": f"Test query {i}"}, + "files": [], + "user_id": account.id, + "stream": False, + "invoke_from": "published", + "workflow_execution_id": str(uuid.uuid4()), + "pipeline_config": { + "app_id": str(uuid.uuid4()), + "app_name": fake.company(), + "mode": "workflow", + "workflow_id": workflow.id, + "tenant_id": tenant.id, + "app_mode": "workflow", + }, + "datasource_type": "upload_file", + "datasource_info": {}, + "dataset_id": str(uuid.uuid4()), + "batch": "test_batch", + } + + entity = RagPipelineInvokeEntity( + pipeline_id=pipeline.id, + application_generate_entity=application_generate_entity, + user_id=account.id, + tenant_id=tenant.id, + workflow_id=workflow.id, + streaming=False, + workflow_execution_id=str(uuid.uuid4()), + workflow_thread_pool_id=str(uuid.uuid4()), + ) + entities.append(entity) + + return entities + + def _create_file_content_for_entities(self, entities): + """ + Helper method to create file content for RAG pipeline invoke entities. + + Args: + entities: List of RagPipelineInvokeEntity instances + + Returns: + str: JSON string containing serialized entities + """ + entities_data = [entity.model_dump() for entity in entities] + return json.dumps(entities_data) + + def test_priority_rag_pipeline_run_task_success( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test successful priority RAG pipeline run task execution. 
+ + This test verifies: + - Task execution with multiple RAG pipeline invoke entities + - File content retrieval and parsing + - PipelineGenerator._generate method calls with correct parameters + - Thread pool execution + - File cleanup after execution + - Queue management with no waiting tasks + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Act: Execute the priority task + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify expected outcomes + # Verify file operations + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + + # Verify PipelineGenerator._generate was called for each entity + assert mock_pipeline_generator.call_count == 2 + + # Verify call parameters for each entity + calls = mock_pipeline_generator.call_args_list + for call in calls: + call_kwargs = call[1] # Get keyword arguments + assert call_kwargs["pipeline"].id == pipeline.id + assert call_kwargs["workflow_id"] == workflow.id + assert call_kwargs["user"].id == account.id + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["streaming"] == False + assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) + + def test_rag_pipeline_run_task_success( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test successful regular RAG pipeline run task execution. 
+ + This test verifies: + - Task execution with multiple RAG pipeline invoke entities + - File content retrieval and parsing + - PipelineGenerator._generate method calls with correct parameters + - Thread pool execution + - File cleanup after execution + - Queue management with no waiting tasks + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Act: Execute the regular task + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify expected outcomes + # Verify file operations + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + + # Verify PipelineGenerator._generate was called for each entity + assert mock_pipeline_generator.call_count == 3 + + # Verify call parameters for each entity + calls = mock_pipeline_generator.call_args_list + for call in calls: + call_kwargs = call[1] # Get keyword arguments + assert call_kwargs["pipeline"].id == pipeline.id + assert call_kwargs["workflow_id"] == workflow.id + assert call_kwargs["user"].id == account.id + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["streaming"] == False + assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) + + def test_priority_rag_pipeline_run_task_with_waiting_tasks( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test priority RAG pipeline run task with waiting tasks in queue using real Redis. 
+ + This test verifies: + - Core task execution + - Real Redis-based tenant queue processing of waiting tasks + - Task function calls for waiting tasks + - Queue management with multiple tasks using actual Redis operations + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting tasks to the real Redis queue + waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)] + queue.push_tasks(waiting_file_ids) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act: Execute the priority task + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify core processing occurred + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting tasks were processed, pull 1 task a time by default + assert mock_delay.call_count == 1 + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue still has remaining tasks (only 1 was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining + + def test_rag_pipeline_run_task_legacy_compatibility( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. 
+
+        This test simulates the scenario where:
+        - Old code writes file IDs directly to Redis list using lpush
+        - New worker processes these legacy queue entries
+        - Ensures backward compatibility during deployment transition
+
+        Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
+        New format: TenantIsolatedTaskQueue.push_tasks([file_id])
+        """
+        # Arrange: Create test data
+        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
+        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
+        file_content = self._create_file_content_for_entities(entities)
+
+        # Mock file service
+        file_id = str(uuid.uuid4())
+        mock_file_service["get_content"].return_value = file_content
+
+        # Simulate legacy Redis queue format - direct file IDs in Redis list
+        from extensions.ext_redis import redis_client
+
+        # Legacy queue key format (old code)
+        legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
+        legacy_task_key = f"tenant_pipeline_task:{tenant.id}"
+
+        # Add legacy format data to Redis (simulating old code behavior)
+        legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
+        for file_id_legacy in legacy_file_ids:
+            redis_client.lpush(legacy_queue_key, file_id_legacy)
+
+        # Set the task key to indicate there are waiting tasks (legacy behavior)
+        redis_client.set(legacy_task_key, 1, ex=60 * 60)
+
+        # Mock the task function calls
+        with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
+            # Act: Execute the regular task with new code but legacy queue data
+            rag_pipeline_run_task(file_id, tenant.id)
+
+            # Assert: Verify core processing occurred
+            mock_file_service["get_content"].assert_called_once_with(file_id)
+            mock_file_service["delete_file"].assert_called_once_with(file_id)
+            assert mock_pipeline_generator.call_count == 1
+
+            # Verify waiting tasks were processed, pulling 1 task at a time by default
+            assert mock_delay.call_count == 1
+
+            # Verify correct parameters for the call
+            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
+            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
+            assert call_kwargs.get("tenant_id") == tenant.id
+
+            # Verify that new code can process legacy queue entries
+            # The new TenantIsolatedTaskQueue should be able to read from the legacy format
+            queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
+
+            # Verify queue still has remaining tasks (only 1 was pulled)
+            remaining_tasks = queue.pull_tasks(count=10)
+            assert len(remaining_tasks) == 2  # 3 original - 1 pulled = 2 remaining
+
+            # Cleanup: Remove legacy test data
+            redis_client.delete(legacy_queue_key)
+            redis_client.delete(legacy_task_key)
+
+    def test_rag_pipeline_run_task_with_waiting_tasks(
+        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
+    ):
+        """
+        Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
+ + This test verifies: + - Core task execution + - Real Redis-based tenant queue processing of waiting tasks + - Task function calls for waiting tasks + - Queue management with multiple tasks using actual Redis operations + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting tasks to the real Redis queue + waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)] + queue.push_tasks(waiting_file_ids) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify core processing occurred + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting tasks were processed, pull 1 task a time by default + assert mock_delay.call_count == 1 + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue still has remaining tasks (only 1 was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining + + def test_priority_rag_pipeline_run_task_error_handling( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test error handling in priority RAG pipeline run task using real Redis. 
+ + This test verifies: + - Exception handling during core processing + - Tenant queue cleanup even on errors using real Redis + - Proper error logging + - Function completes without raising exceptions + - Queue management continues despite core processing errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Mock PipelineGenerator to raise an exception + mock_pipeline_generator.side_effect = Exception("Pipeline generation failed") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act: Execute the priority task (should not raise exception) + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + # The function should not raise exceptions + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting task was still processed despite core processing error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_rag_pipeline_run_task_error_handling( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test error handling in regular RAG pipeline run task using real Redis. 
+ + This test verifies: + - Exception handling during core processing + - Tenant queue cleanup even on errors using real Redis + - Proper error logging + - Function completes without raising exceptions + - Queue management continues despite core processing errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + file_content = self._create_file_content_for_entities(entities) + + # Mock file service + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].return_value = file_content + + # Mock PipelineGenerator to raise an exception + mock_pipeline_generator.side_effect = Exception("Pipeline generation failed") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task (should not raise exception) + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + # The function should not raise exceptions + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_file_service["delete_file"].assert_called_once_with(file_id) + assert mock_pipeline_generator.call_count == 1 + + # Verify waiting task was still processed despite core processing error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_priority_rag_pipeline_run_task_tenant_isolation( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test tenant isolation in priority RAG pipeline run task using real Redis. 
+ + This test verifies: + - Different tenants have isolated queues + - Tasks from one tenant don't affect another tenant's queue + - Queue operations are properly scoped to tenant + """ + # Arrange: Create test data for two different tenants + account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers) + account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers) + + entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1) + entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1) + + file_content1 = self._create_file_content_for_entities(entities1) + file_content2 = self._create_file_content_for_entities(entities2) + + # Mock file service + file_id1 = str(uuid.uuid4()) + file_id2 = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = [file_content1, file_content2] + + # Use real Redis for TenantIsolatedTaskQueue + queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline") + queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline") + + # Add waiting tasks to both queues + waiting_file_id1 = str(uuid.uuid4()) + waiting_file_id2 = str(uuid.uuid4()) + + queue1.push_tasks([waiting_file_id1]) + queue2.push_tasks([waiting_file_id2]) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act: Execute the priority task for tenant1 only + priority_rag_pipeline_run_task(file_id1, tenant1.id) + + # Assert: Verify core processing occurred for tenant1 + assert mock_file_service["get_content"].call_count == 1 + assert mock_file_service["delete_file"].call_count == 1 + assert mock_pipeline_generator.call_count == 1 + + # Verify only tenant1's waiting task was processed + mock_delay.assert_called_once() + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 + assert call_kwargs.get("tenant_id") == tenant1.id + + # Verify tenant1's queue is empty + remaining_tasks1 = queue1.pull_tasks(count=10) + assert len(remaining_tasks1) == 0 + + # Verify tenant2's queue still has its task (isolation) + remaining_tasks2 = queue2.pull_tasks(count=10) + assert len(remaining_tasks2) == 1 + + # Verify queue keys are different + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + + def test_rag_pipeline_run_task_tenant_isolation( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test tenant isolation in regular RAG pipeline run task using real Redis. 
+ + This test verifies: + - Different tenants have isolated queues + - Tasks from one tenant don't affect another tenant's queue + - Queue operations are properly scoped to tenant + """ + # Arrange: Create test data for two different tenants + account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers) + account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers) + + entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1) + entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1) + + file_content1 = self._create_file_content_for_entities(entities1) + file_content2 = self._create_file_content_for_entities(entities2) + + # Mock file service + file_id1 = str(uuid.uuid4()) + file_id2 = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = [file_content1, file_content2] + + # Use real Redis for TenantIsolatedTaskQueue + queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline") + queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline") + + # Add waiting tasks to both queues + waiting_file_id1 = str(uuid.uuid4()) + waiting_file_id2 = str(uuid.uuid4()) + + queue1.push_tasks([waiting_file_id1]) + queue2.push_tasks([waiting_file_id2]) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act: Execute the regular task for tenant1 only + rag_pipeline_run_task(file_id1, tenant1.id) + + # Assert: Verify core processing occurred for tenant1 + assert mock_file_service["get_content"].call_count == 1 + assert mock_file_service["delete_file"].call_count == 1 + assert mock_pipeline_generator.call_count == 1 + + # Verify only tenant1's waiting task was processed + mock_delay.assert_called_once() + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 + assert call_kwargs.get("tenant_id") == tenant1.id + + # Verify tenant1's queue is empty + remaining_tasks1 = queue1.pull_tasks(count=10) + assert len(remaining_tasks1) == 0 + + # Verify tenant2's queue still has its task (isolation) + remaining_tasks2 = queue2.pull_tasks(count=10) + assert len(remaining_tasks2) == 1 + + # Verify queue keys are different + assert queue1._queue != queue2._queue + assert queue1._task_key != queue2._task_key + + def test_run_single_rag_pipeline_task_success( + self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + ): + """ + Test successful run_single_rag_pipeline_task execution. 
+ + This test verifies: + - Single RAG pipeline task execution within Flask app context + - Entity validation and database queries + - PipelineGenerator._generate method call with correct parameters + - Proper Flask context handling + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1) + entity_data = entities[0].model_dump() + + # Act: Execute the single task + with flask_app_with_containers.app_context(): + run_single_rag_pipeline_task(entity_data, flask_app_with_containers) + + # Assert: Verify expected outcomes + # Verify PipelineGenerator._generate was called + assert mock_pipeline_generator.call_count == 1 + + # Verify call parameters + call = mock_pipeline_generator.call_args + call_kwargs = call[1] # Get keyword arguments + assert call_kwargs["pipeline"].id == pipeline.id + assert call_kwargs["workflow_id"] == workflow.id + assert call_kwargs["user"].id == account.id + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["streaming"] == False + assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) + + def test_run_single_rag_pipeline_task_entity_validation_error( + self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + ): + """ + Test run_single_rag_pipeline_task with invalid entity data. + + This test verifies: + - Proper error handling for invalid entity data + - Exception logging + - Function raises ValueError for missing entities + """ + # Arrange: Create entity data with valid UUIDs but non-existent entities + fake = Faker() + invalid_entity_data = { + "pipeline_id": str(uuid.uuid4()), + "application_generate_entity": { + "app_config": { + "app_id": str(uuid.uuid4()), + "app_name": "Test App", + "mode": "workflow", + "workflow_id": str(uuid.uuid4()), + }, + "inputs": {"query": "Test query"}, + "query": "Test query", + "response_mode": "blocking", + "user": str(uuid.uuid4()), + "files": [], + "conversation_id": str(uuid.uuid4()), + }, + "user_id": str(uuid.uuid4()), + "tenant_id": str(uuid.uuid4()), + "workflow_id": str(uuid.uuid4()), + "streaming": False, + "workflow_execution_id": str(uuid.uuid4()), + "workflow_thread_pool_id": str(uuid.uuid4()), + } + + # Act & Assert: Execute the single task with non-existent entities (should raise ValueError) + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Account .* not found"): + run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers) + + # Assert: Pipeline generator should not be called + mock_pipeline_generator.assert_not_called() + + def test_run_single_rag_pipeline_task_database_entity_not_found( + self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + ): + """ + Test run_single_rag_pipeline_task with non-existent database entities. 
+ + This test verifies: + - Proper error handling for missing database entities + - Exception logging + - Function raises ValueError for missing entities + """ + # Arrange: Create test data with non-existent IDs + fake = Faker() + entity_data = { + "pipeline_id": str(uuid.uuid4()), + "application_generate_entity": { + "app_config": { + "app_id": str(uuid.uuid4()), + "app_name": "Test App", + "mode": "workflow", + "workflow_id": str(uuid.uuid4()), + }, + "inputs": {"query": "Test query"}, + "query": "Test query", + "response_mode": "blocking", + "user": str(uuid.uuid4()), + "files": [], + "conversation_id": str(uuid.uuid4()), + }, + "user_id": str(uuid.uuid4()), + "tenant_id": str(uuid.uuid4()), + "workflow_id": str(uuid.uuid4()), + "streaming": False, + "workflow_execution_id": str(uuid.uuid4()), + "workflow_thread_pool_id": str(uuid.uuid4()), + } + + # Act & Assert: Execute the single task with non-existent entities (should raise ValueError) + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Account .* not found"): + run_single_rag_pipeline_task(entity_data, flask_app_with_containers) + + # Assert: Pipeline generator should not be called + mock_pipeline_generator.assert_not_called() + + def test_priority_rag_pipeline_run_task_file_not_found( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test priority RAG pipeline run task with non-existent file. + + This test verifies: + - Proper error handling for missing files + - Exception logging + - Function raises Exception for file errors + - Queue management continues despite file errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + + # Mock file service to raise exception + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = Exception("File not found") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch( + "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay" + ) as mock_delay: + # Act & Assert: Execute the priority task (should raise Exception) + with pytest.raises(Exception, match="File not found"): + priority_rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_pipeline_generator.assert_not_called() + + # Verify waiting task was still processed despite file error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 + + def test_rag_pipeline_run_task_file_not_found( + self, db_session_with_containers, mock_pipeline_generator, mock_file_service + ): + """ + Test regular RAG pipeline run task with non-existent file. 
+ + This test verifies: + - Proper error handling for missing files + - Exception logging + - Function raises Exception for file errors + - Queue management continues despite file errors + """ + # Arrange: Create test data + account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers) + + # Mock file service to raise exception + file_id = str(uuid.uuid4()) + mock_file_service["get_content"].side_effect = Exception("File not found") + + # Use real Redis for TenantIsolatedTaskQueue + queue = TenantIsolatedTaskQueue(tenant.id, "pipeline") + + # Add waiting task to the real Redis queue + waiting_file_id = str(uuid.uuid4()) + queue.push_tasks([waiting_file_id]) + + # Mock the task function calls + with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Act & Assert: Execute the regular task (should raise Exception) + with pytest.raises(Exception, match="File not found"): + rag_pipeline_run_task(file_id, tenant.id) + + # Assert: Verify error was handled gracefully + mock_file_service["get_content"].assert_called_once_with(file_id) + mock_pipeline_generator.assert_not_called() + + # Verify waiting task was still processed despite file error + mock_delay.assert_called_once() + + # Verify correct parameters for the call + call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} + assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert call_kwargs.get("tenant_id") == tenant.id + + # Verify queue is empty after processing (task was pulled) + remaining_tasks = queue.pull_tasks(count=10) + assert len(remaining_tasks) == 0 diff --git a/api/tests/unit_tests/core/rag/pipeline/test_queue.py b/api/tests/unit_tests/core/rag/pipeline/test_queue.py new file mode 100644 index 0000000000..cfdbb30f8f --- /dev/null +++ b/api/tests/unit_tests/core/rag/pipeline/test_queue.py @@ -0,0 +1,301 @@ +""" +Unit tests for TenantIsolatedTaskQueue. + +These tests verify the Redis-based task queue functionality for tenant-specific +task management with proper serialization and deserialization. 
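+
+Illustrative usage (a minimal sketch based only on the behaviors exercised by the
+tests below; the tenant id, key name, and task payloads are example values):
+
+    from core.rag.pipeline.queue import TenantIsolatedTaskQueue
+
+    queue = TenantIsolatedTaskQueue("tenant-123", "pipeline")
+    queue.push_tasks(["file-id-1", {"dataset_id": "ds-1"}])  # strings pass through, dicts are JSON-wrapped
+    queue.set_task_waiting_time()  # sets the per-tenant in-flight flag (default TTL 3600 seconds)
+    tasks = queue.pull_tasks(count=1)  # pops from the opposite end of the push (FIFO), one task by default
+    queue.delete_task_key()  # clears the in-flight flag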
+""" + +import json +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue + + +class TestTaskWrapper: + """Test cases for TaskWrapper serialization/deserialization.""" + + def test_serialize_simple_data(self): + """Test serialization of simple data types.""" + data = {"key": "value", "number": 42, "list": [1, 2, 3]} + wrapper = TaskWrapper(data=data) + + serialized = wrapper.serialize() + assert isinstance(serialized, str) + + # Verify it's valid JSON + parsed = json.loads(serialized) + assert parsed["data"] == data + + def test_serialize_complex_data(self): + """Test serialization of complex nested data.""" + data = { + "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}}, + "unicode": "测试中文", + "special_chars": "!@#$%^&*()", + } + wrapper = TaskWrapper(data=data) + + serialized = wrapper.serialize() + parsed = json.loads(serialized) + assert parsed["data"] == data + + def test_deserialize_valid_data(self): + """Test deserialization of valid JSON data.""" + original_data = {"key": "value", "number": 42} + # Serialize using TaskWrapper to get the correct format + wrapper = TaskWrapper(data=original_data) + serialized = wrapper.serialize() + + wrapper = TaskWrapper.deserialize(serialized) + assert wrapper.data == original_data + + def test_deserialize_invalid_json(self): + """Test deserialization handles invalid JSON gracefully.""" + invalid_json = "{invalid json}" + + # Pydantic will raise ValidationError for invalid JSON + with pytest.raises(ValidationError): + TaskWrapper.deserialize(invalid_json) + + def test_serialize_ensure_ascii_false(self): + """Test that serialization preserves Unicode characters.""" + data = {"chinese": "中文测试", "emoji": "🚀"} + wrapper = TaskWrapper(data=data) + + serialized = wrapper.serialize() + assert "中文测试" in serialized + assert "🚀" in serialized + + +class TestTenantIsolatedTaskQueue: + """Test cases for TenantIsolatedTaskQueue functionality.""" + + @pytest.fixture + def mock_redis_client(self): + """Mock Redis client for testing.""" + mock_redis = MagicMock() + return mock_redis + + @pytest.fixture + def sample_queue(self, mock_redis_client): + """Create a sample TenantIsolatedTaskQueue instance.""" + return TenantIsolatedTaskQueue("tenant-123", "test-key") + + def test_initialization(self, sample_queue): + """Test queue initialization with correct key generation.""" + assert sample_queue._tenant_id == "tenant-123" + assert sample_queue._unique_key == "test-key" + assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123" + assert sample_queue._task_key == "tenant_test-key_task:tenant-123" + + @patch("core.rag.pipeline.queue.redis_client") + def test_get_task_key_exists(self, mock_redis, sample_queue): + """Test getting task key when it exists.""" + mock_redis.get.return_value = "1" + + result = sample_queue.get_task_key() + + assert result == "1" + mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_get_task_key_not_exists(self, mock_redis, sample_queue): + """Test getting task key when it doesn't exist.""" + mock_redis.get.return_value = None + + result = sample_queue.get_task_key() + + assert result is None + mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_set_task_waiting_time_default_ttl(self, mock_redis, 
sample_queue): + """Test setting task waiting flag with default TTL.""" + sample_queue.set_task_waiting_time() + + mock_redis.setex.assert_called_once_with( + "tenant_test-key_task:tenant-123", + 3600, # DEFAULT_TASK_TTL + 1, + ) + + @patch("core.rag.pipeline.queue.redis_client") + def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue): + """Test setting task waiting flag with custom TTL.""" + custom_ttl = 1800 + sample_queue.set_task_waiting_time(custom_ttl) + + mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1) + + @patch("core.rag.pipeline.queue.redis_client") + def test_delete_task_key(self, mock_redis, sample_queue): + """Test deleting task key.""" + sample_queue.delete_task_key() + + mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_push_tasks_string_list(self, mock_redis, sample_queue): + """Test pushing string tasks directly.""" + tasks = ["task1", "task2", "task3"] + + sample_queue.push_tasks(tasks) + + mock_redis.lpush.assert_called_once_with( + "tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3" + ) + + @patch("core.rag.pipeline.queue.redis_client") + def test_push_tasks_mixed_types(self, mock_redis, sample_queue): + """Test pushing mixed string and object tasks.""" + tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"] + + sample_queue.push_tasks(tasks) + + # Verify lpush was called + mock_redis.lpush.assert_called_once() + call_args = mock_redis.lpush.call_args + + # Check queue name + assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123" + + # Check serialized tasks + serialized_tasks = call_args[0][1:] + assert len(serialized_tasks) == 3 + assert serialized_tasks[0] == "string_task" + assert serialized_tasks[2] == "another_string" + + # Check object task is serialized as TaskWrapper JSON (without prefix) + # It should be a valid JSON string that can be deserialized by TaskWrapper + wrapper = TaskWrapper.deserialize(serialized_tasks[1]) + assert wrapper.data == {"object_task": "data", "id": 123} + + @patch("core.rag.pipeline.queue.redis_client") + def test_push_tasks_empty_list(self, mock_redis, sample_queue): + """Test pushing empty task list.""" + sample_queue.push_tasks([]) + + mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123") + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_default_count(self, mock_redis, sample_queue): + """Test pulling tasks with default count (1).""" + mock_redis.rpop.side_effect = ["task1", None] + + result = sample_queue.pull_tasks() + + assert result == ["task1"] + assert mock_redis.rpop.call_count == 1 + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_custom_count(self, mock_redis, sample_queue): + """Test pulling tasks with custom count.""" + # First test: pull 3 tasks + mock_redis.rpop.side_effect = ["task1", "task2", "task3", None] + + result = sample_queue.pull_tasks(3) + + assert result == ["task1", "task2", "task3"] + assert mock_redis.rpop.call_count == 3 + + # Reset mock for second test + mock_redis.reset_mock() + mock_redis.rpop.side_effect = ["task1", "task2", None] + + result = sample_queue.pull_tasks(3) + + assert result == ["task1", "task2"] + assert mock_redis.rpop.call_count == 3 + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_zero_count(self, mock_redis, sample_queue): + """Test pulling tasks with zero count returns 
empty list.""" + result = sample_queue.pull_tasks(0) + + assert result == [] + mock_redis.rpop.assert_not_called() + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_negative_count(self, mock_redis, sample_queue): + """Test pulling tasks with negative count returns empty list.""" + result = sample_queue.pull_tasks(-1) + + assert result == [] + mock_redis.rpop.assert_not_called() + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue): + """Test pulling tasks that include wrapped objects.""" + # Create a wrapped task + task_data = {"task_id": 123, "data": "test"} + wrapper = TaskWrapper(data=task_data) + wrapped_task = wrapper.serialize() + + mock_redis.rpop.side_effect = [ + "string_task", + wrapped_task.encode("utf-8"), # Simulate bytes from Redis + None, + ] + + result = sample_queue.pull_tasks(2) + + assert len(result) == 2 + assert result[0] == "string_task" + assert result[1] == {"task_id": 123, "data": "test"} + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue): + """Test pulling tasks with invalid JSON falls back to string.""" + # Invalid JSON string that cannot be deserialized + invalid_json = "invalid json data" + mock_redis.rpop.side_effect = [invalid_json, None] + + result = sample_queue.pull_tasks(1) + + assert result == [invalid_json] + + @patch("core.rag.pipeline.queue.redis_client") + def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue): + """Test pulling tasks handles bytes from Redis correctly.""" + mock_redis.rpop.side_effect = [ + b"task1", # bytes + "task2", # string + None, + ] + + result = sample_queue.pull_tasks(2) + + assert result == ["task1", "task2"] + + @patch("core.rag.pipeline.queue.redis_client") + def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue): + """Test complex object serialization and deserialization roundtrip.""" + complex_task = { + "id": uuid4().hex, + "data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}}, + "metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]}, + } + + # Push the complex task + sample_queue.push_tasks([complex_task]) + + # Verify it was serialized as TaskWrapper JSON + call_args = mock_redis.lpush.call_args + wrapped_task = call_args[0][1] + # Verify it's a valid TaskWrapper JSON (starts with {"data":) + assert wrapped_task.startswith('{"data":') + + # Verify it can be deserialized + wrapper = TaskWrapper.deserialize(wrapped_task) + assert wrapper.data == complex_task + + # Simulate pulling it back + mock_redis.rpop.return_value = wrapped_task + result = sample_queue.pull_tasks(1) + + assert len(result) == 1 + assert result[0] == complex_task diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py new file mode 100644 index 0000000000..d9183be9fb --- /dev/null +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -0,0 +1,317 @@ +from unittest.mock import Mock, patch + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.document_indexing_task_proxy import DocumentIndexingTaskProxy + + +class DocumentIndexingTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests.""" + + 
@staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_document_task_proxy( + tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None + ) -> DocumentIndexingTaskProxy: + """Create DocumentIndexingTaskProxy instance for testing.""" + if document_ids is None: + document_ids = ["doc-1", "doc-2", "doc-3"] + return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + +class TestDocumentIndexingTaskProxy: + """Test cases for DocumentIndexingTaskProxy class.""" + + def test_initialization(self): + """Test DocumentIndexingTaskProxy initialization.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2", "doc-3"] + + # Act + proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(mock_task) + + # Assert + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + mock_task.delay = Mock() + 
+ # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() + pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] + assert len(pushed_tasks) == 1 + assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask) + assert pushed_tasks[0]["tenant_id"] == "tenant-123" + assert pushed_tasks[0]["dataset_id"] == "dataset-456" + assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] + mock_task.delay.assert_not_called() + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + def test_send_to_default_tenant_queue(self, mock_task): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_default_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + + @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + def test_send_to_priority_tenant_queue(self, mock_task): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_priority_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + + @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + def test_send_to_priority_direct_queue(self, mock_task): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_direct_queue = Mock() + + # Act + proxy._send_to_priority_direct_queue() + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(mock_task) + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is 
enabled with non-sandbox plan.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # If billing enabled with non sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_with_billing_disabled(self, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + # Act + proxy._dispatch() + + # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_delay_method(self, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy.delay() + + # Assert + # If billing enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once() + + def test_document_task_dataclass(self): + """Test DocumentTask dataclass.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2"] + + # Act + task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) + + # Assert + assert task.tenant_id == tenant_id + assert task.dataset_id == dataset_id + assert task.document_ids == document_ids + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_edge_case_empty_plan(self, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_task_proxy.FeatureService") + def test_dispatch_edge_case_none_plan(self, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) + mock_feature_service.get_features.return_value = mock_features + proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + def 
test_initialization_with_empty_document_ids(self): + """Test initialization with empty document_ids list.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [] + + # Act + proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_single_document_id(self): + """Test initialization with single document_id.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1"] + + # Act + proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids diff --git a/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py b/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py new file mode 100644 index 0000000000..f5a48b1416 --- /dev/null +++ b/api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py @@ -0,0 +1,483 @@ +import json +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy + + +class RagPipelineTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for RagPipelineTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_rag_pipeline_invoke_entity( + pipeline_id: str = "pipeline-123", + user_id: str = "user-456", + tenant_id: str = "tenant-789", + workflow_id: str = "workflow-101", + streaming: bool = True, + workflow_execution_id: str | None = None, + workflow_thread_pool_id: str | None = None, + ) -> RagPipelineInvokeEntity: + """Create RagPipelineInvokeEntity instance for testing.""" + return RagPipelineInvokeEntity( + pipeline_id=pipeline_id, + application_generate_entity={"key": "value"}, + user_id=user_id, + tenant_id=tenant_id, + workflow_id=workflow_id, + streaming=streaming, + workflow_execution_id=workflow_execution_id, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + @staticmethod + def create_rag_pipeline_task_proxy( + dataset_tenant_id: str = "tenant-123", + user_id: str = "user-456", + rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None, + ) -> RagPipelineTaskProxy: + """Create RagPipelineTaskProxy instance for testing.""" + if rag_pipeline_invoke_entities is None: + rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()] + return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) 
+ + @staticmethod + def create_mock_upload_file(file_id: str = "file-123") -> Mock: + """Create mock upload file.""" + upload_file = Mock() + upload_file.id = file_id + return upload_file + + +class TestRagPipelineTaskProxy: + """Test cases for RagPipelineTaskProxy class.""" + + def test_initialization(self): + """Test RagPipelineTaskProxy initialization.""" + # Arrange + dataset_tenant_id = "tenant-123" + user_id = "user-456" + rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()] + + # Act + proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + # Assert + assert proxy._dataset_tenant_id == dataset_tenant_id + assert proxy._user_id == user_id + assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "pipeline" + + def test_initialization_with_empty_entities(self): + """Test initialization with empty rag_pipeline_invoke_entities.""" + # Arrange + dataset_tenant_id = "tenant-123" + user_id = "user-456" + rag_pipeline_invoke_entities = [] + + # Act + proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + # Assert + assert proxy._dataset_tenant_id == dataset_tenant_id + assert proxy._user_id == user_id + assert proxy._rag_pipeline_invoke_entities == [] + + def test_initialization_with_multiple_entities(self): + """Test initialization with multiple rag_pipeline_invoke_entities.""" + # Arrange + dataset_tenant_id = "tenant-123" + user_id = "user-456" + rag_pipeline_invoke_entities = [ + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"), + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"), + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"), + ] + + # Act + proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities) + + # Assert + assert len(proxy._rag_pipeline_invoke_entities) == 3 + assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1" + assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2" + assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3" + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_upload_invoke_entities(self, mock_db, mock_file_service_class): + """Test _upload_invoke_entities method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + 
mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + result = proxy._upload_invoke_entities() + + # Assert + assert result == "file-123" + mock_file_service_class.assert_called_once_with(mock_db.engine) + + # Verify upload_text was called with correct parameters + mock_file_service.upload_text.assert_called_once() + call_args = mock_file_service.upload_text.call_args + json_text, name, user_id, tenant_id = call_args[0] + + assert name == "rag_pipeline_invoke_entities.json" + assert user_id == "user-456" + assert tenant_id == "tenant-123" + + # Verify JSON content + parsed_json = json.loads(json_text) + assert len(parsed_json) == 1 + assert parsed_json[0]["pipeline_id"] == "pipeline-123" + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class): + """Test _upload_invoke_entities method with multiple entities.""" + # Arrange + entities = [ + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"), + RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"), + ] + proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities) + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + result = proxy._upload_invoke_entities() + + # Assert + assert result == "file-456" + + # Verify JSON content contains both entities + call_args = mock_file_service.upload_text.call_args + json_text = call_args[0][0] + parsed_json = json.loads(json_text) + assert len(parsed_json) == 2 + assert parsed_json[0]["pipeline_id"] == "pipeline-1" + assert parsed_json[1]["pipeline_id"] == "pipeline-2" + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue() + upload_file_id = "file-123" + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(upload_file_id, mock_task) + + # If sent to direct queue, tenant_isolated_task_queue should not be called + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + # Celery should be called directly + mock_task.delay.assert_called_once_with( + rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123" + ) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + upload_file_id = "file-123" + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(upload_file_id, mock_task) + + # If task key exists, should push tasks to the queue 
+ proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id]) + # Celery should not be called directly + mock_task.delay.assert_not_called() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + upload_file_id = "file-123" + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(upload_file_id, mock_task) + + # If no task key, should set task waiting time key first + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123" + ) + + # The first task should be sent to celery directly, so push tasks should not be called + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task") + def test_send_to_default_tenant_queue(self, mock_task): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_tenant_queue = Mock() + upload_file_id = "file-123" + + # Act + proxy._send_to_default_tenant_queue(upload_file_id) + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task") + def test_send_to_priority_tenant_queue(self, mock_task): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_tenant_queue = Mock() + upload_file_id = "file-123" + + # Act + proxy._send_to_priority_tenant_queue(upload_file_id) + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task") + def test_send_to_priority_direct_queue(self, mock_task): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_direct_queue = Mock() + upload_file_id = "file-123" + + # Act + proxy._send_to_priority_direct_queue(upload_file_id) + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task) + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = 
RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # If billing is enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_billing_enabled_non_sandbox_plan( + self, mock_db, mock_file_service_class, mock_feature_service + ): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # If billing is enabled with non-sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class): + """Test _dispatch method when upload_file_id is empty.""" + # Arrange + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = Mock() + mock_upload_file.id = "" # Empty file ID + mock_file_service.upload_text.return_value = mock_upload_file + + # Act & Assert + with pytest.raises(ValueError, match="upload_file_id is empty"): + proxy._dispatch() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def 
test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="") + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123") + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService") + @patch("services.rag_pipeline.rag_pipeline_task_proxy.db") + def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy() + proxy._dispatch = Mock() + + mock_file_service = Mock() + mock_file_service_class.return_value = mock_file_service + mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123") + mock_file_service.upload_text.return_value = mock_upload_file + + # Act + proxy.delay() + + # Assert + proxy._dispatch.assert_called_once() + + @patch("services.rag_pipeline.rag_pipeline_task_proxy.logger") + def test_delay_method_with_empty_entities(self, mock_logger): + """Test delay method with empty rag_pipeline_invoke_entities.""" + # Arrange + proxy = RagPipelineTaskProxy("tenant-123", "user-456", []) + + # Act + proxy.delay() + + # Assert + mock_logger.warning.assert_called_once_with( + "Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456" + ) diff --git a/dev/start-worker b/dev/start-worker index 83d7bf0f3c..9cf448c9c6 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.." 
uv --directory api run \ celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline + -P gevent -c 1 --loglevel INFO -Q dataset,priority_dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline diff --git a/docker/.env.example b/docker/.env.example index c19084ebbf..1ccc11d01b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1370,3 +1370,6 @@ ENABLE_CLEAN_MESSAGES=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true + +# Tenant isolated task queue configuration +TENANT_ISOLATED_TASK_CONCURRENCY=1 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1ff33a94b5..07d6cd46ab 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -614,6 +614,7 @@ x-shared-env: &shared-api-worker-env ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false} ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false} ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true} + TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} services: # API service
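
For context on what the proxy and queue tests above exercise, below is a minimal sketch (not the shipped api/core/rag/pipeline/queue.py) of the tenant-isolated queue behaviour implied by the unit tests: non-string tasks are wrapped as TaskWrapper JSON, pushed with LPUSH onto a per-tenant Redis list, and popped with RPOP in batches bounded by TENANT_ISOLATED_TASK_CONCURRENCY. The Redis key layout, the stand-in redis_client, and the constructor argument order are assumptions inferred from the tests, and the delivery-gating helpers (get_task_key / set_task_waiting_time) are omitted.

    # Hedged sketch of the tenant-isolated queue behaviour covered by the tests above.
    # Key names and the client setup are illustrative assumptions, not the real module.
    import json
    from dataclasses import dataclass
    from typing import Any

    import redis

    # Assumption: the real module exposes a module-level redis_client (the tests patch one).
    redis_client = redis.Redis()


    @dataclass
    class TaskWrapper:
        data: Any

        def serialize(self) -> str:
            # Non-string tasks are stored as {"data": ...} JSON before hitting Redis.
            return json.dumps({"data": self.data})

        @classmethod
        def deserialize(cls, raw: str) -> "TaskWrapper":
            return cls(data=json.loads(raw)["data"])


    class TenantIsolatedTaskQueue:
        def __init__(self, tenant_id: str, unique_key: str):
            self._tenant_id = tenant_id
            self._unique_key = unique_key
            # Assumed key layout; the real key format may differ.
            self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"

        def push_tasks(self, tasks: list[Any]) -> None:
            # Strings (e.g. upload file IDs) are pushed as-is; other objects are wrapped.
            for task in tasks:
                payload = task if isinstance(task, str) else TaskWrapper(task).serialize()
                redis_client.lpush(self._queue, payload)

        def pull_tasks(self, count: int) -> list[Any]:
            # Zero or negative counts pull nothing, matching the tests.
            if count <= 0:
                return []
            results: list[Any] = []
            for _ in range(count):
                raw = redis_client.rpop(self._queue)
                if raw is None:
                    break
                text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
                try:
                    results.append(TaskWrapper.deserialize(text).data)
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Plain string tasks fall back untouched when they are not wrapper JSON.
                    results.append(text)
            return results

The dispatch policy the proxy tests pin down is: billing enabled on the sandbox plan routes work through the tenant-isolated queue backed by the default Celery task, billing enabled on any other plan routes through the tenant-isolated queue backed by the priority task, and billing disabled (self-hosted or enterprise) bypasses the tenant queue and calls the priority task directly.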