mirror of https://github.com/langgenius/dify.git
feat: enable tenant isolation on duplicate document indexing tasks (#29080)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
parent e6d504558a
commit 3cb944f318
@@ -51,7 +51,8 @@ from models.model import UploadFile
 from models.provider_ids import ModelProviderID
 from models.source import DataSourceOauthBinding
 from models.workflow import Workflow
-from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
+from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
+from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
@@ -82,7 +83,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 from tasks.disable_segment_from_index_task import disable_segment_from_index_task
 from tasks.disable_segments_from_index_task import disable_segments_from_index_task
 from tasks.document_indexing_update_task import document_indexing_update_task
-from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.enable_segments_to_index_task import enable_segments_to_index_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task
@@ -1761,7 +1761,9 @@ class DocumentService:
             if document_ids:
                 DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
             if duplicate_document_ids:
-                duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
+                DuplicateDocumentIndexingTaskProxy(
+                    dataset.tenant_id, dataset.id, duplicate_document_ids
+                ).delay()
         except LockNotOwnedError:
             pass

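With this change both batches flow through task proxies instead of raw Celery `.delay()` calls, so duplicate re-indexing gets the same plan-aware, tenant-isolated routing that fresh indexing already had. A minimal sketch of the call-site contract (the tenant, dataset, and document IDs below are illustrative placeholders, not values from this diff):

```python
# Illustrative call-site sketch; IDs are placeholders.
from services.document_indexing_proxy import (
    DocumentIndexingTaskProxy,
    DuplicateDocumentIndexingTaskProxy,
)

document_ids = ["doc-1", "doc-2"]    # first-time uploads
duplicate_document_ids = ["doc-3"]   # re-uploaded files

if document_ids:
    DocumentIndexingTaskProxy("tenant-1", "dataset-1", document_ids).delay()
if duplicate_document_ids:
    DuplicateDocumentIndexingTaskProxy("tenant-1", "dataset-1", duplicate_document_ids).delay()
```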

@@ -0,0 +1,11 @@
+from .base import DocumentTaskProxyBase
+from .batch_indexing_base import BatchDocumentIndexingProxy
+from .document_indexing_task_proxy import DocumentIndexingTaskProxy
+from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
+
+__all__ = [
+    "BatchDocumentIndexingProxy",
+    "DocumentIndexingTaskProxy",
+    "DocumentTaskProxyBase",
+    "DuplicateDocumentIndexingTaskProxy",
+]

@@ -0,0 +1,111 @@
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from functools import cached_property
+from typing import Any, ClassVar
+
+from enums.cloud_plan import CloudPlan
+from services.feature_service import FeatureService
+
+logger = logging.getLogger(__name__)
+
+
+class DocumentTaskProxyBase(ABC):
+    """
+    Base proxy for all document processing tasks.
+
+    Handles common logic:
+    - Feature/billing checks
+    - Dispatch routing based on plan
+
+    Subclasses must define:
+    - QUEUE_NAME: Redis queue identifier
+    - NORMAL_TASK_FUNC: Task function for normal priority
+    - PRIORITY_TASK_FUNC: Task function for high priority
+    """
+
+    QUEUE_NAME: ClassVar[str]
+    NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]]
+    PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]]
+
+    def __init__(self, tenant_id: str, dataset_id: str):
+        """
+        Initialize with minimal required parameters.
+
+        Args:
+            tenant_id: Tenant identifier for billing/features
+            dataset_id: Dataset identifier for logging
+        """
+        self._tenant_id = tenant_id
+        self._dataset_id = dataset_id
+
+    @cached_property
+    def features(self):
+        return FeatureService.get_features(self._tenant_id)
+
+    @abstractmethod
+    def _send_to_direct_queue(self, task_func: Callable[..., Any]):
+        """
+        Send task directly to Celery queue without tenant isolation.
+
+        Subclasses implement this to pass task-specific parameters.
+
+        Args:
+            task_func: The Celery task function to call
+        """
+        pass
+
+    @abstractmethod
+    def _send_to_tenant_queue(self, task_func: Callable[..., Any]):
+        """
+        Send task to tenant-isolated queue.
+
+        Subclasses implement this to handle queue management.
+
+        Args:
+            task_func: The Celery task function to call
+        """
+        pass
+
+    def _send_to_default_tenant_queue(self):
+        """Route to normal priority with tenant isolation."""
+        self._send_to_tenant_queue(self.NORMAL_TASK_FUNC)
+
+    def _send_to_priority_tenant_queue(self):
+        """Route to priority queue with tenant isolation."""
+        self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC)
+
+    def _send_to_priority_direct_queue(self):
+        """Route to priority queue without tenant isolation."""
+        self._send_to_direct_queue(self.PRIORITY_TASK_FUNC)
+
+    def _dispatch(self):
+        """
+        Dispatch task based on billing plan.
+
+        Routing logic:
+        - Sandbox plan → normal queue + tenant isolation
+        - Paid plans → priority queue + tenant isolation
+        - Self-hosted → priority queue, no isolation
+        """
+        logger.info(
+            "dispatch args: %s - %s - %s",
+            self._tenant_id,
+            self.features.billing.enabled,
+            self.features.billing.subscription.plan,
+        )
+        # dispatch to different indexing queue with tenant isolation when billing enabled
+        if self.features.billing.enabled:
+            if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
+                # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
+                self._send_to_default_tenant_queue()
+            else:
+                # dispatch to priority pipeline queue with tenant self sub queue for other plans
+                self._send_to_priority_tenant_queue()
+        else:
+            # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
+            self._send_to_priority_direct_queue()
+
+    def delay(self):
+        """Public API: Queue the task asynchronously."""
+        self._dispatch()
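`delay()` is a template method: it funnels into `_dispatch()`, which picks a `_send_to_*` helper, which in turn calls one of the two abstract hooks with the right class-level task function. The runnable sketch below exercises that contract with stub tasks and an overridden `features` property so it never touches `FeatureService`; every `Stub*` name is invented for illustration:

```python
# Runnable sketch of the DocumentTaskProxyBase contract; Stub* names are
# invented, and `features` is overridden to avoid a real FeatureService call.
from types import SimpleNamespace
from typing import Any, ClassVar


class StubTask:
    """Stands in for a Celery task object; the proxy only uses .delay()."""

    def __init__(self, name: str):
        self.name = name

    def delay(self, **kwargs):
        print(f"{self.name}.delay({kwargs})")


class StubProxy(DocumentTaskProxyBase):
    QUEUE_NAME: ClassVar[str] = "stub_queue"
    NORMAL_TASK_FUNC: ClassVar[Any] = StubTask("normal_stub_task")
    PRIORITY_TASK_FUNC: ClassVar[Any] = StubTask("priority_stub_task")

    @property
    def features(self):
        # Billing disabled -> _dispatch() takes the priority direct route.
        return SimpleNamespace(billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=None)))

    def _send_to_direct_queue(self, task_func):
        task_func.delay(tenant_id=self._tenant_id, dataset_id=self._dataset_id)

    def _send_to_tenant_queue(self, task_func):
        task_func.delay(tenant_id=self._tenant_id, dataset_id=self._dataset_id)


StubProxy("tenant-1", "dataset-1").delay()  # prints: priority_stub_task.delay(...)
```

Storing Celery tasks as class attributes works because a task produced by `shared_task` is an object rather than a plain function; a plain function assigned this way would be re-bound as an instance method when accessed through `self`.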

@@ -0,0 +1,76 @@
+import logging
+from collections.abc import Callable, Sequence
+from dataclasses import asdict
+from typing import Any
+
+from core.entities.document_task import DocumentTask
+from core.rag.pipeline.queue import TenantIsolatedTaskQueue
+
+from .base import DocumentTaskProxyBase
+
+logger = logging.getLogger(__name__)
+
+
+class BatchDocumentIndexingProxy(DocumentTaskProxyBase):
+    """
+    Base proxy for batch document indexing tasks (document_ids in plural).
+
+    Adds:
+    - Tenant isolated queue management
+    - Batch document handling
+    """
+
+    def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
+        """
+        Initialize with batch documents.
+
+        Args:
+            tenant_id: Tenant identifier
+            dataset_id: Dataset identifier
+            document_ids: List of document IDs to process
+        """
+        super().__init__(tenant_id, dataset_id)
+        self._document_ids = document_ids
+        self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME)
+
+    def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]):
+        """
+        Send batch task to direct queue.
+
+        Args:
+            task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids)
+        """
+        logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids)
+        task_func.delay(  # type: ignore
+            tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
+        )
+
+    def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]):
+        """
+        Send batch task to tenant-isolated queue.
+
+        Args:
+            task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids)
+        """
+        logger.info(
+            "tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME
+        )
+        if self._tenant_isolated_task_queue.get_task_key():
+            # Add to waiting queue using List operations (lpush)
+            self._tenant_isolated_task_queue.push_tasks(
+                [
+                    asdict(
+                        DocumentTask(
+                            tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
+                        )
+                    )
+                ]
+            )
+            logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids)
+        else:
+            # Set flag and execute task
+            self._tenant_isolated_task_queue.set_task_waiting_time()
+            task_func.delay(  # type: ignore
+                tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
+            )
+            logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids)
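The gate-or-enqueue logic above relies on two Redis structures per tenant and queue name: a "task key" acting as a busy flag, and a list of waiting task payloads. `TenantIsolatedTaskQueue` itself is not part of this diff (it lives in `core.rag.pipeline.queue`), so the sketch below is an assumed minimal implementation of just the methods used here; the key names and TTL are guesses, not the real ones:

```python
# Assumed sketch of the queue contract consumed above; the real
# TenantIsolatedTaskQueue may differ in key names, TTLs, and serialization.
import json

import redis

r = redis.Redis()


class TenantIsolatedTaskQueueSketch:
    def __init__(self, tenant_id: str, queue_name: str):
        self._flag_key = f"tenant_task:{queue_name}:{tenant_id}"       # busy flag (assumed shape)
        self._list_key = f"tenant_task_list:{queue_name}:{tenant_id}"  # waiting tasks

    def get_task_key(self):
        return r.get(self._flag_key)

    def set_task_waiting_time(self, ttl_seconds: int = 3600):
        # Mark the tenant busy; the TTL guards against a crashed worker
        # leaving the flag set forever.
        r.set(self._flag_key, "1", ex=ttl_seconds)

    def push_tasks(self, tasks):
        r.lpush(self._list_key, *(json.dumps(t) for t in tasks))

    def pull_tasks(self, count: int = 1):
        tasks = []
        for _ in range(count):
            raw = r.rpop(self._list_key)  # lpush + rpop -> FIFO
            if raw is None:
                break
            tasks.append(json.loads(raw))
        return tasks

    def delete_task_key(self):
        r.delete(self._flag_key)
```

Under this contract, the first batch for an idle tenant runs immediately and sets the flag; every batch that arrives while the flag is set is parked in the list instead of hitting Celery.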

@@ -0,0 +1,12 @@
+from typing import ClassVar
+
+from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
+from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
+
+
+class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
+    """Proxy for document indexing tasks."""
+
+    QUEUE_NAME: ClassVar[str] = "document_indexing"
+    NORMAL_TASK_FUNC = normal_document_indexing_task
+    PRIORITY_TASK_FUNC = priority_document_indexing_task

@@ -0,0 +1,15 @@
+from typing import ClassVar
+
+from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
+from tasks.duplicate_document_indexing_task import (
+    normal_duplicate_document_indexing_task,
+    priority_duplicate_document_indexing_task,
+)
+
+
+class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
+    """Proxy for duplicate document indexing tasks."""
+
+    QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing"
+    NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task
+    PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task

@@ -1,83 +0,0 @@
-import logging
-from collections.abc import Callable, Sequence
-from dataclasses import asdict
-from functools import cached_property
-
-from core.entities.document_task import DocumentTask
-from core.rag.pipeline.queue import TenantIsolatedTaskQueue
-from enums.cloud_plan import CloudPlan
-from services.feature_service import FeatureService
-from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
-
-logger = logging.getLogger(__name__)
-
-
-class DocumentIndexingTaskProxy:
-    def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
-        self._tenant_id = tenant_id
-        self._dataset_id = dataset_id
-        self._document_ids = document_ids
-        self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
-
-    @cached_property
-    def features(self):
-        return FeatureService.get_features(self._tenant_id)
-
-    def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
-        logger.info("send dataset %s to direct queue", self._dataset_id)
-        task_func.delay(  # type: ignore
-            tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
-        )
-
-    def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
-        logger.info("send dataset %s to tenant queue", self._dataset_id)
-        if self._tenant_isolated_task_queue.get_task_key():
-            # Add to waiting queue using List operations (lpush)
-            self._tenant_isolated_task_queue.push_tasks(
-                [
-                    asdict(
-                        DocumentTask(
-                            tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
-                        )
-                    )
-                ]
-            )
-            logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
-        else:
-            # Set flag and execute task
-            self._tenant_isolated_task_queue.set_task_waiting_time()
-            task_func.delay(  # type: ignore
-                tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
-            )
-            logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
-
-    def _send_to_default_tenant_queue(self):
-        self._send_to_tenant_queue(normal_document_indexing_task)
-
-    def _send_to_priority_tenant_queue(self):
-        self._send_to_tenant_queue(priority_document_indexing_task)
-
-    def _send_to_priority_direct_queue(self):
-        self._send_to_direct_queue(priority_document_indexing_task)
-
-    def _dispatch(self):
-        logger.info(
-            "dispatch args: %s - %s - %s",
-            self._tenant_id,
-            self.features.billing.enabled,
-            self.features.billing.subscription.plan,
-        )
-        # dispatch to different indexing queue with tenant isolation when billing enabled
-        if self.features.billing.enabled:
-            if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
-                # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
-                self._send_to_default_tenant_queue()
-            else:
-                # dispatch to priority pipeline queue with tenant self sub queue for other plans
-                self._send_to_priority_tenant_queue()
-        else:
-            # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
-            self._send_to_priority_direct_queue()
-
-    def delay(self):
-        self._dispatch()

@@ -38,21 +38,24 @@ class RagPipelineTaskProxy:
         upload_file = FileService(db.engine).upload_text(
             json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
         )
+        logger.info(
+            "tenant %s upload %d invoke entities", self._dataset_tenant_id, len(self._rag_pipeline_invoke_entities)
+        )
         return upload_file.id

     def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
-        logger.info("send file %s to direct queue", upload_file_id)
+        logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id)
         task_func.delay(  # type: ignore
             rag_pipeline_invoke_entities_file_id=upload_file_id,
             tenant_id=self._dataset_tenant_id,
         )

     def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
-        logger.info("send file %s to tenant queue", upload_file_id)
+        logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id)
         if self._tenant_isolated_task_queue.get_task_key():
             # Add to waiting queue using List operations (lpush)
             self._tenant_isolated_task_queue.push_tasks([upload_file_id])
-            logger.info("push tasks: %s", upload_file_id)
+            logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id)
         else:
             # Set flag and execute task
             self._tenant_isolated_task_queue.set_task_waiting_time()
@@ -60,7 +63,7 @@ class RagPipelineTaskProxy:
                 rag_pipeline_invoke_entities_file_id=upload_file_id,
                 tenant_id=self._dataset_tenant_id,
             )
-            logger.info("init tasks: %s", upload_file_id)
+            logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id)

     def _send_to_default_tenant_queue(self, upload_file_id: str):
         self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)

@@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue(
    try:
        _document_indexing(dataset_id, document_ids)
    except Exception:
-        logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
+        logger.exception(
+            "Error processing document indexing %s for tenant %s: %s",
+            dataset_id,
+            tenant_id,
+            document_ids,
+            exc_info=True,
+        )
    finally:
        tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")

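The replaced one-liner was a latent bug as well as a readability problem: its format string has three `%s` placeholders but only two arguments were supplied, so the logging call itself would error. The new call passes all three. One nit worth knowing: `logger.exception` already attaches the active traceback, so the extra `exc_info=True` is redundant (though harmless), as this small standalone example shows:

```python
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

try:
    1 / 0
except ZeroDivisionError:
    # These two calls emit equivalent records, traceback included:
    # logger.exception implies exc_info=True.
    logger.exception("indexing failed for dataset %s", "dataset-1")
    logger.error("indexing failed for dataset %s", "dataset-1", exc_info=True)
```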
@@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue(
        # Use rpop to get the next task from the queue (FIFO order)
        next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)

-        logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
+        logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)

        if next_tasks:
            for next_task in next_tasks:

@@ -1,13 +1,16 @@
 import logging
 import time
+from collections.abc import Callable, Sequence

 import click
 from celery import shared_task
 from sqlalchemy import select

 from configs import dify_config
+from core.entities.document_task import DocumentTask
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.pipeline.queue import TenantIsolatedTaskQueue
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
@@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
     :param dataset_id:
     :param document_ids:

+    .. warning:: TO BE DEPRECATED
+        This function will be deprecated and removed in a future version.
+        Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead.
+
     Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
     """
+    logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids)
+    _duplicate_document_indexing_task(dataset_id, document_ids)
+
+
+def _duplicate_document_indexing_task_with_tenant_queue(
+    tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
+):
+    try:
+        _duplicate_document_indexing_task(dataset_id, document_ids)
+    except Exception:
+        logger.exception(
+            "Error processing duplicate document indexing %s for tenant %s: %s",
+            dataset_id,
+            tenant_id,
+            document_ids,
+            exc_info=True,
+        )
+    finally:
+        tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing")
+
+        # Check if there are waiting tasks in the queue
+        # Use rpop to get the next task from the queue (FIFO order)
+        next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
+
+        logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
+
+        if next_tasks:
+            for next_task in next_tasks:
+                document_task = DocumentTask(**next_task)
+                # Process the next waiting task
+                # Keep the flag set to indicate a task is running
+                tenant_isolated_task_queue.set_task_waiting_time()
+                task_func.delay(  # type: ignore
+                    tenant_id=document_task.tenant_id,
+                    dataset_id=document_task.dataset_id,
+                    document_ids=document_task.document_ids,
+                )
+        else:
+            # No more waiting tasks, clear the flag
+            tenant_isolated_task_queue.delete_task_key()
+
+
+def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
     documents = []
     start_at = time.perf_counter()

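Together with the proxy, this wrapper gives each tenant a drain loop: whether the batch succeeds or fails, the `finally` block pops up to `TENANT_ISOLATED_TASK_CONCURRENCY` waiting batches, keeps the busy flag set, and re-delays them; only an empty list clears the flag so the next proxy call dispatches immediately. Because producers `lpush` and this side `rpop`s, batches drain oldest-first; a tiny self-contained demonstration of that pairing (a plain Python list stands in for Redis):

```python
# lpush + rpop is FIFO across batches; a list stands in for Redis here.
queue: list[str] = []


def lpush(item: str) -> None:
    queue.insert(0, item)  # Redis LPUSH prepends on the left


def rpop() -> str | None:
    return queue.pop() if queue else None  # Redis RPOP takes from the right


for batch in ("batch-1", "batch-2", "batch-3"):
    lpush(batch)

assert [rpop(), rpop(), rpop()] == ["batch-1", "batch-2", "batch-3"]
```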
@@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
         logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
     finally:
         db.session.close()
+
+
+@shared_task(queue="dataset")
+def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
+    """
+    Async process duplicate documents
+    :param tenant_id:
+    :param dataset_id:
+    :param document_ids:
+
+    Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
+    """
+    logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
+    _duplicate_document_indexing_task_with_tenant_queue(
+        tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
+    )
+
+
+@shared_task(queue="priority_dataset")
+def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
+    """
+    Async process duplicate documents
+    :param tenant_id:
+    :param dataset_id:
+    :param document_ids:
+
+    Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
+    """
+    logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
+    _duplicate_document_indexing_task_with_tenant_queue(
+        tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
+    )
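The normal and priority variants differ only in the `queue=` argument to `@shared_task`, so priority is decided entirely by which Celery queue a worker consumes; the tenant-isolation wrapper is shared. A hedged sketch of how that separation plays out (worker commands are illustrative; Dify's actual worker deployment is outside this diff):

```python
# Queue separation sketch. Workers would consume the two lanes separately,
# e.g. (illustrative commands, not from this diff):
#
#   celery -A app worker -Q dataset             # sandbox-plan batches
#   celery -A app worker -Q priority_dataset    # paid-plan batches
#
# In-process, .delay() on each task routes to its declared queue:
normal_duplicate_document_indexing_task.delay(
    tenant_id="tenant-1", dataset_id="dataset-1", document_ids=["doc-1"]
)  # -> "dataset" queue
priority_duplicate_document_indexing_task.delay(
    tenant_id="tenant-1", dataset_id="dataset-1", document_ids=["doc-1"]
)  # -> "priority_dataset" queue
```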

@@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task(
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

+        logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities))
+
        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

@@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task(
        end_at = time.perf_counter()
        logging.info(
            click.style(
-                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
+                f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
@@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task(
        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
-        logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
+        logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)

        if next_file_ids:
            for next_file_id in next_file_ids:

@@ -47,6 +47,8 @@ def rag_pipeline_run_task(
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

+        logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities))
+
        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

@@ -66,7 +68,7 @@ def rag_pipeline_run_task(
        end_at = time.perf_counter()
        logging.info(
            click.style(
-                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
+                f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
@@ -78,7 +80,7 @@ def rag_pipeline_run_task(
        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
-        logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
+        logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)

        if next_file_ids:
            for next_file_id in next_file_ids:

@@ -0,0 +1,763 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from faker import Faker
+
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from models.dataset import Dataset, Document, DocumentSegment
+from tasks.duplicate_document_indexing_task import (
+    _duplicate_document_indexing_task,  # Core function
+    _duplicate_document_indexing_task_with_tenant_queue,  # Tenant queue wrapper function
+    duplicate_document_indexing_task,  # Deprecated old interface
+    normal_duplicate_document_indexing_task,  # New normal task
+    priority_duplicate_document_indexing_task,  # New priority task
+)
+
+
+class TestDuplicateDocumentIndexingTasks:
+    """Integration tests for duplicate document indexing tasks using testcontainers.
+
+    This test class covers:
+    - Core _duplicate_document_indexing_task function
+    - Deprecated duplicate_document_indexing_task function
+    - New normal_duplicate_document_indexing_task function
+    - New priority_duplicate_document_indexing_task function
+    - Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function
+    - Document segment cleanup logic
+    """
+
+    @pytest.fixture
+    def mock_external_service_dependencies(self):
+        """Mock setup for external service dependencies."""
+        with (
+            patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner,
+            patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service,
+            patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory,
+        ):
+            # Setup mock indexing runner
+            mock_runner_instance = MagicMock()
+            mock_indexing_runner.return_value = mock_runner_instance
+
+            # Setup mock feature service
+            mock_features = MagicMock()
+            mock_features.billing.enabled = False
+            mock_feature_service.get_features.return_value = mock_features
+
+            # Setup mock index processor factory
+            mock_processor = MagicMock()
+            mock_processor.clean = MagicMock()
+            mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
+
+            yield {
+                "indexing_runner": mock_indexing_runner,
+                "indexing_runner_instance": mock_runner_instance,
+                "feature_service": mock_feature_service,
+                "features": mock_features,
+                "index_processor_factory": mock_index_processor_factory,
+                "index_processor": mock_processor,
+            }
+
+    def _create_test_dataset_and_documents(
+        self, db_session_with_containers, mock_external_service_dependencies, document_count=3
+    ):
+        """
+        Helper method to create a test dataset and documents for testing.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            mock_external_service_dependencies: Mock dependencies
+            document_count: Number of documents to create
+
+        Returns:
+            tuple: (dataset, documents) - Created dataset and document instances
+        """
+        fake = Faker()
+
+        # Create account and tenant
+        account = Account(
+            email=fake.email(),
+            name=fake.name(),
+            interface_language="en-US",
+            status="active",
+        )
+        db.session.add(account)
+        db.session.commit()
+
+        tenant = Tenant(
+            name=fake.company(),
+            status="normal",
+        )
+        db.session.add(tenant)
+        db.session.commit()
+
+        # Create tenant-account join
+        join = TenantAccountJoin(
+            tenant_id=tenant.id,
+            account_id=account.id,
+            role=TenantAccountRole.OWNER,
+            current=True,
+        )
+        db.session.add(join)
+        db.session.commit()
+
+        # Create dataset
+        dataset = Dataset(
+            id=fake.uuid4(),
+            tenant_id=tenant.id,
+            name=fake.company(),
+            description=fake.text(max_nb_chars=100),
+            data_source_type="upload_file",
+            indexing_technique="high_quality",
+            created_by=account.id,
+        )
+        db.session.add(dataset)
+        db.session.commit()
+
+        # Create documents
+        documents = []
+        for i in range(document_count):
+            document = Document(
+                id=fake.uuid4(),
+                tenant_id=tenant.id,
+                dataset_id=dataset.id,
+                position=i,
+                data_source_type="upload_file",
+                batch="test_batch",
+                name=fake.file_name(),
+                created_from="upload_file",
+                created_by=account.id,
+                indexing_status="waiting",
+                enabled=True,
+                doc_form="text_model",
+            )
+            db.session.add(document)
+            documents.append(document)
+
+        db.session.commit()
+
+        # Refresh dataset to ensure it's properly loaded
+        db.session.refresh(dataset)
+
+        return dataset, documents
+
+    def _create_test_dataset_with_segments(
+        self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2
+    ):
+        """
+        Helper method to create a test dataset with documents and segments.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            mock_external_service_dependencies: Mock dependencies
+            document_count: Number of documents to create
+            segments_per_doc: Number of segments per document
+
+        Returns:
+            tuple: (dataset, documents, segments) - Created dataset, documents and segments
+        """
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count
+        )
+
+        fake = Faker()
+        segments = []
+
+        # Create segments for each document
+        for document in documents:
+            for i in range(segments_per_doc):
+                segment = DocumentSegment(
+                    id=fake.uuid4(),
+                    tenant_id=dataset.tenant_id,
+                    dataset_id=dataset.id,
+                    document_id=document.id,
+                    position=i,
+                    index_node_id=f"{document.id}-node-{i}",
+                    index_node_hash=fake.sha256(),
+                    content=fake.text(max_nb_chars=200),
+                    word_count=50,
+                    tokens=100,
+                    status="completed",
+                    enabled=True,
+                    indexing_at=fake.date_time_this_year(),
+                    created_by=dataset.created_by,  # Add required field
+                )
+                db.session.add(segment)
+                segments.append(segment)
+
+        db.session.commit()
+
+        # Refresh to ensure all relationships are loaded
+        for document in documents:
+            db.session.refresh(document)
+
+        return dataset, documents, segments
+
+    def _create_test_dataset_with_billing_features(
+        self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
+    ):
+        """
+        Helper method to create a test dataset with billing features configured.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            mock_external_service_dependencies: Mock dependencies
+            billing_enabled: Whether billing is enabled
+
+        Returns:
+            tuple: (dataset, documents) - Created dataset and document instances
+        """
+        fake = Faker()
+
+        # Create account and tenant
+        account = Account(
+            email=fake.email(),
+            name=fake.name(),
+            interface_language="en-US",
+            status="active",
+        )
+        db.session.add(account)
+        db.session.commit()
+
+        tenant = Tenant(
+            name=fake.company(),
+            status="normal",
+        )
+        db.session.add(tenant)
+        db.session.commit()
+
+        # Create tenant-account join
+        join = TenantAccountJoin(
+            tenant_id=tenant.id,
+            account_id=account.id,
+            role=TenantAccountRole.OWNER,
+            current=True,
+        )
+        db.session.add(join)
+        db.session.commit()
+
+        # Create dataset
+        dataset = Dataset(
+            id=fake.uuid4(),
+            tenant_id=tenant.id,
+            name=fake.company(),
+            description=fake.text(max_nb_chars=100),
+            data_source_type="upload_file",
+            indexing_technique="high_quality",
+            created_by=account.id,
+        )
+        db.session.add(dataset)
+        db.session.commit()
+
+        # Create documents
+        documents = []
+        for i in range(3):
+            document = Document(
+                id=fake.uuid4(),
+                tenant_id=tenant.id,
+                dataset_id=dataset.id,
+                position=i,
+                data_source_type="upload_file",
+                batch="test_batch",
+                name=fake.file_name(),
+                created_from="upload_file",
+                created_by=account.id,
+                indexing_status="waiting",
+                enabled=True,
+                doc_form="text_model",
+            )
+            db.session.add(document)
+            documents.append(document)
+
+        db.session.commit()
+
+        # Configure billing features
+        mock_external_service_dependencies["features"].billing.enabled = billing_enabled
+        if billing_enabled:
+            mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
+            mock_external_service_dependencies["features"].vector_space.limit = 100
+            mock_external_service_dependencies["features"].vector_space.size = 50
+
+        # Refresh dataset to ensure it's properly loaded
+        db.session.refresh(dataset)
+
+        return dataset, documents
+
+    def test_duplicate_document_indexing_task_success(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test successful duplicate document indexing with multiple documents.
+
+        This test verifies:
+        - Proper dataset retrieval from database
+        - Correct document processing and status updates
+        - IndexingRunner integration
+        - Database state updates
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=3
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Act: Execute the task
+        _duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify the expected outcomes
+        # Verify indexing runner was called correctly
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Verify documents were updated to parsing status
+        # Re-query documents from database since _duplicate_document_indexing_task uses a different session
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+            assert updated_document.processing_started_at is not None
+
+        # Verify the run method was called with correct documents
+        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
+        assert call_args is not None
+        processed_documents = call_args[0][0]  # First argument should be documents list
+        assert len(processed_documents) == 3
+
+    def test_duplicate_document_indexing_task_with_segment_cleanup(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test duplicate document indexing with existing segments that need cleanup.
+
+        This test verifies:
+        - Old segments are identified and cleaned
+        - Index processor clean method is called
+        - Segments are deleted from database
+        - New indexing proceeds after cleanup
+        """
+        # Arrange: Create test data with existing segments
+        dataset, documents, segments = self._create_test_dataset_with_segments(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Act: Execute the task
+        _duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify segment cleanup
+        # Verify index processor clean was called for each document with segments
+        assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
+
+        # Verify segments were deleted from database
+        # Re-query segments from database since _duplicate_document_indexing_task uses a different session
+        for segment in segments:
+            deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
+            assert deleted_segment is None
+
+        # Verify documents were updated to parsing status
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+            assert updated_document.processing_started_at is not None
+
+        # Verify indexing runner was called
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+    def test_duplicate_document_indexing_task_dataset_not_found(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of non-existent dataset.
+
+        This test verifies:
+        - Proper error handling for missing datasets
+        - Early return without processing
+        - Database session cleanup
+        - No unnecessary indexing runner calls
+        """
+        # Arrange: Use non-existent dataset ID
+        fake = Faker()
+        non_existent_dataset_id = fake.uuid4()
+        document_ids = [fake.uuid4() for _ in range(3)]
+
+        # Act: Execute the task with non-existent dataset
+        _duplicate_document_indexing_task(non_existent_dataset_id, document_ids)
+
+        # Assert: Verify no processing occurred
+        mock_external_service_dependencies["indexing_runner"].assert_not_called()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
+        mock_external_service_dependencies["index_processor"].clean.assert_not_called()
+
+    def test_duplicate_document_indexing_task_document_not_found_in_dataset(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling when some documents don't exist in the dataset.
+
+        This test verifies:
+        - Only existing documents are processed
+        - Non-existent documents are ignored
+        - Indexing runner receives only valid documents
+        - Database state updates correctly
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2
+        )
+
+        # Mix existing and non-existent document IDs
+        fake = Faker()
+        existing_document_ids = [doc.id for doc in documents]
+        non_existent_document_ids = [fake.uuid4() for _ in range(2)]
+        all_document_ids = existing_document_ids + non_existent_document_ids
+
+        # Act: Execute the task with mixed document IDs
+        _duplicate_document_indexing_task(dataset.id, all_document_ids)
+
+        # Assert: Verify only existing documents were processed
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Verify only existing documents were updated
+        # Re-query documents from database since _duplicate_document_indexing_task uses a different session
+        for doc_id in existing_document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+            assert updated_document.processing_started_at is not None
+
+        # Verify the run method was called with only existing documents
+        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
+        assert call_args is not None
+        processed_documents = call_args[0][0]  # First argument should be documents list
+        assert len(processed_documents) == 2  # Only existing documents
+
+    def test_duplicate_document_indexing_task_indexing_runner_exception(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of IndexingRunner exceptions.
+
+        This test verifies:
+        - Exceptions from IndexingRunner are properly caught
+        - Task completes without raising exceptions
+        - Database session is properly closed
+        - Error logging occurs
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Mock IndexingRunner to raise an exception
+        mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception(
+            "Indexing runner failed"
+        )
+
+        # Act: Execute the task
+        _duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify exception was handled gracefully
+        # The task should complete without raising exceptions
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Verify documents were still updated to parsing status before the exception
+        # Re-query documents from database since _duplicate_document_indexing_task close the session
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+            assert updated_document.processing_started_at is not None
+
+    def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test billing validation for sandbox plan batch upload limit.
+
+        This test verifies:
+        - Sandbox plan batch upload limit enforcement
+        - Error handling for batch upload limit exceeded
+        - Document status updates to error state
+        - Proper error message recording
+        """
+        # Arrange: Create test data with billing enabled
+        dataset, documents = self._create_test_dataset_with_billing_features(
+            db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
+        )
+
+        # Configure sandbox plan with batch limit
+        mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
+
+        # Create more documents than sandbox plan allows (limit is 1)
+        fake = Faker()
+        extra_documents = []
+        for i in range(2):  # Total will be 5 documents (3 existing + 2 new)
+            document = Document(
+                id=fake.uuid4(),
+                tenant_id=dataset.tenant_id,
+                dataset_id=dataset.id,
+                position=i + 3,
+                data_source_type="upload_file",
+                batch="test_batch",
+                name=fake.file_name(),
+                created_from="upload_file",
+                created_by=dataset.created_by,
+                indexing_status="waiting",
+                enabled=True,
+                doc_form="text_model",
+            )
+            db.session.add(document)
+            extra_documents.append(document)
+
+        db.session.commit()
+        all_documents = documents + extra_documents
+        document_ids = [doc.id for doc in all_documents]
+
+        # Act: Execute the task with too many documents for sandbox plan
+        _duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify error handling
+        # Re-query documents from database since _duplicate_document_indexing_task uses a different session
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "error"
+            assert updated_document.error is not None
+            assert "batch upload" in updated_document.error.lower()
+            assert updated_document.stopped_at is not None
+
+        # Verify indexing runner was not called due to early validation error
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
+
+    def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test billing validation for vector space limit.
+
+        This test verifies:
+        - Vector space limit enforcement
+        - Error handling for vector space limit exceeded
+        - Document status updates to error state
+        - Proper error message recording
+        """
+        # Arrange: Create test data with billing enabled
+        dataset, documents = self._create_test_dataset_with_billing_features(
+            db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
+        )
+
+        # Configure TEAM plan with vector space limit exceeded
+        mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM
+        mock_external_service_dependencies["features"].vector_space.limit = 100
+        mock_external_service_dependencies["features"].vector_space.size = 98  # Almost at limit
+
+        document_ids = [doc.id for doc in documents]  # 3 documents will exceed limit
+
+        # Act: Execute the task with documents that will exceed vector space limit
+        _duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify error handling
+        # Re-query documents from database since _duplicate_document_indexing_task uses a different session
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "error"
+            assert updated_document.error is not None
+            assert "limit" in updated_document.error.lower()
+            assert updated_document.stopped_at is not None
+
+        # Verify indexing runner was not called due to early validation error
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
+
+    def test_duplicate_document_indexing_task_with_empty_document_list(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of empty document list.
+
+        This test verifies:
+        - Empty document list is handled gracefully
+        - No processing occurs
+        - No errors are raised
+        - Database session is properly closed
+        """
+        # Arrange: Create test dataset
+        dataset, _ = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=0
+        )
+        document_ids = []
+
+        # Act: Execute the task with empty document list
+        _duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify IndexingRunner was called with empty list
+        # Note: The actual implementation does call run([]) with empty list
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([])
+
+    def test_deprecated_duplicate_document_indexing_task_delegates_to_core(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test that deprecated duplicate_document_indexing_task delegates to core function.
+
+        This test verifies:
+        - Deprecated function calls core _duplicate_document_indexing_task
+        - Proper parameter passing
+        - Backward compatibility
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Act: Execute the deprecated task
+        duplicate_document_indexing_task(dataset.id, document_ids)
+
+        # Assert: Verify core function was executed
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Clear session cache to see database updates from task's session
+        db.session.expire_all()
+
+        # Verify documents were processed
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+
+    @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
+    def test_normal_duplicate_document_indexing_task_with_tenant_queue(
+        self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test normal_duplicate_document_indexing_task with tenant isolation queue.
+
+        This test verifies:
+        - Task uses tenant isolation queue correctly
+        - Core processing function is called
+        - Queue management (pull tasks, delete key) works properly
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Mock tenant isolated queue to return no next tasks
+        mock_queue = MagicMock()
+        mock_queue.pull_tasks.return_value = []
+        mock_queue_class.return_value = mock_queue
+
+        # Act: Execute the normal task
+        normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
+
+        # Assert: Verify processing occurred
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Verify tenant queue was used
+        mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing")
+        mock_queue.pull_tasks.assert_called_once()
+        mock_queue.delete_task_key.assert_called_once()
+
+        # Clear session cache to see database updates from task's session
+        db.session.expire_all()
+
+        # Verify documents were processed
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+
+    @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
+    def test_priority_duplicate_document_indexing_task_with_tenant_queue(
+        self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test priority_duplicate_document_indexing_task with tenant isolation queue.
+
+        This test verifies:
+        - Task uses tenant isolation queue correctly
+        - Core processing function is called
+        - Queue management works properly
+        - Same behavior as normal task with different queue assignment
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Mock tenant isolated queue to return no next tasks
+        mock_queue = MagicMock()
+        mock_queue.pull_tasks.return_value = []
+        mock_queue_class.return_value = mock_queue
+
+        # Act: Execute the priority task
+        priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
+
+        # Assert: Verify processing occurred
+        mock_external_service_dependencies["indexing_runner"].assert_called_once()
+        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
+
+        # Verify tenant queue was used
+        mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing")
+        mock_queue.pull_tasks.assert_called_once()
+        mock_queue.delete_task_key.assert_called_once()
+
+        # Clear session cache to see database updates from task's session
+        db.session.expire_all()
+
+        # Verify documents were processed
+        for doc_id in document_ids:
+            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
+            assert updated_document.indexing_status == "parsing"
+
+    @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
+    def test_tenant_queue_wrapper_processes_next_tasks(
+        self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test tenant queue wrapper processes next queued tasks.
+
+        This test verifies:
+        - After completing current task, next tasks are pulled from queue
+        - Next tasks are executed correctly
+        - Task waiting time is set for next tasks
+        """
+        # Arrange: Create test data
+        dataset, documents = self._create_test_dataset_and_documents(
+            db_session_with_containers, mock_external_service_dependencies, document_count=2
+        )
+        document_ids = [doc.id for doc in documents]
+
+        # Extract values before session detachment
+        tenant_id = dataset.tenant_id
+        dataset_id = dataset.id
+
+        # Mock tenant isolated queue to return next task
+        mock_queue = MagicMock()
+        next_task = {
+            "tenant_id": tenant_id,
+            "dataset_id": dataset_id,
+            "document_ids": document_ids,
+        }
+        mock_queue.pull_tasks.return_value = [next_task]
+        mock_queue_class.return_value = mock_queue
+
+        # Mock the task function to track calls
+        mock_task_func = MagicMock()
+
+        # Act: Execute the wrapper function
+        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
+
+        # Assert: Verify next task was scheduled
+        mock_queue.pull_tasks.assert_called_once()
+        mock_queue.set_task_waiting_time.assert_called_once()
+        mock_task_func.delay.assert_called_once_with(
+            tenant_id=tenant_id,
+            dataset_id=dataset_id,
+            document_ids=document_ids,
+        )
+        mock_queue.delete_task_key.assert_not_called()
|
@@ -117,7 +117,7 @@ import pytest
 from core.entities.document_task import DocumentTask
 from core.rag.pipeline.queue import TenantIsolatedTaskQueue
 from enums.cloud_plan import CloudPlan
-from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
+from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy

 # ============================================================================
 # Test Data Factory
@@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy:
     # Features Property Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_features_property(self, mock_feature_service):
         """
         Test cached_property features.
@@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy:
         mock_feature_service.get_features.assert_called_once_with("tenant-123")

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_features_property_with_different_tenants(self, mock_feature_service):
         """
         Test features property with different tenant IDs.
@@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy:
     # Direct Queue Routing Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_direct_queue(self, mock_task):
         """
         Test _send_to_direct_queue method.
@@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)

-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
     def test_send_to_direct_queue_with_priority_task(self, mock_task):
         """
         Test _send_to_direct_queue with priority task function.
@@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
         )

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_direct_queue_with_single_document(self, mock_task):
         """
         Test _send_to_direct_queue with single document ID.
@@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
         )

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_direct_queue_with_empty_documents(self, mock_task):
         """
         Test _send_to_direct_queue with empty document_ids list.
@@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy:
     # Tenant Queue Routing Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
         """
         Test _send_to_tenant_queue when task key exists.
@@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy:
         mock_task.delay.assert_not_called()

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_without_task_key(self, mock_task):
         """
         Test _send_to_tenant_queue when no task key exists.
@@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy:
         proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()

-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
     def test_send_to_tenant_queue_with_priority_task(self, mock_task):
         """
         Test _send_to_tenant_queue with priority task function.
@@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
         )

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_document_task_serialization(self, mock_task):
         """
         Test DocumentTask serialization in _send_to_tenant_queue.
@@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy:
     # Queue Type Selection Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_default_tenant_queue(self, mock_task):
         """
         Test _send_to_default_tenant_queue method.
@@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_tenant_queue.assert_called_once_with(mock_task)

-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
     def test_send_to_priority_tenant_queue(self, mock_task):
         """
         Test _send_to_priority_tenant_queue method.
@@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_tenant_queue.assert_called_once_with(mock_task)

-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
     def test_send_to_priority_direct_queue(self, mock_task):
         """
         Test _send_to_priority_direct_queue method.
@@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy:
     # Dispatch Logic Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
         """
         Test _dispatch method when billing is enabled with SANDBOX plan.
@@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_default_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service):
         """
         Test _dispatch method when billing is enabled with TEAM plan.
@@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_priority_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service):
         """
         Test _dispatch method when billing is enabled with PROFESSIONAL plan.
@@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_priority_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_with_billing_disabled(self, mock_feature_service):
         """
         Test _dispatch method when billing is disabled.
@@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_priority_direct_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
         """
         Test _dispatch method with empty plan string.
@@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_priority_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_edge_case_none_plan(self, mock_feature_service):
         """
         Test _dispatch method with None plan.
@@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy:
     # Delay Method Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_delay_method(self, mock_feature_service):
         """
         Test delay method integration.
@@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_default_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_delay_method_with_team_plan(self, mock_feature_service):
         """
         Test delay method with TEAM plan.
@@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_priority_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_delay_method_with_billing_disabled(self, mock_feature_service):
         """
         Test delay method with billing disabled.
@@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy:
     # Batch Operations Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_batch_operation_with_multiple_documents(self, mock_task):
         """
         Test batch operation with multiple documents.
@@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids
         )

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_batch_operation_with_large_batch(self, mock_task):
         """
         Test batch operation with large batch of documents.
@@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy:
     # Error Handling Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_direct_queue_task_delay_failure(self, mock_task):
         """
         Test _send_to_direct_queue when task.delay() raises an exception.
@@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy:
         with pytest.raises(Exception, match="Task delay failed"):
             proxy._send_to_direct_queue(mock_task)

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_push_tasks_failure(self, mock_task):
         """
         Test _send_to_tenant_queue when push_tasks raises an exception.
@@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy:
         with pytest.raises(Exception, match="Push tasks failed"):
             proxy._send_to_tenant_queue(mock_task)

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task):
         """
         Test _send_to_tenant_queue when set_task_waiting_time raises an exception.
@@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy:
         with pytest.raises(Exception, match="Set waiting time failed"):
             proxy._send_to_tenant_queue(mock_task)

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
     def test_dispatch_feature_service_failure(self, mock_feature_service):
         """
         Test _dispatch when FeatureService.get_features raises an exception.
@@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy:
     # Integration Tests
     # ========================================================================

-    @patch("services.document_indexing_task_proxy.FeatureService")
-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service):
         """
         Test full flow for SANDBOX plan with tenant queue.
@@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
         )

-    @patch("services.document_indexing_task_proxy.FeatureService")
-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
     def test_full_flow_team_plan(self, mock_task, mock_feature_service):
         """
         Test full flow for TEAM plan with priority tenant queue.
@@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
         )

-    @patch("services.document_indexing_task_proxy.FeatureService")
-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
     def test_full_flow_billing_disabled(self, mock_task, mock_feature_service):
         """
         Test full flow for billing disabled (self-hosted/enterprise).
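# Routing expectations exercised by the dispatch tests above: a sketch of the
# plan-based routing, assuming the shape of _dispatch implied by the assertions
# (only the method names below appear in the tests; the body is illustrative):
#
#     def _dispatch(self):
#         if not self.features.billing.enabled:
#             # Self-hosted / enterprise: no per-tenant throttling needed.
#             self._send_to_priority_direct_queue()
#         elif self.features.billing.subscription.plan == CloudPlan.SANDBOX:
#             self._send_to_default_tenant_queue()
#         else:
#             # TEAM, PROFESSIONAL, and the empty/None plan edge cases fall through here.
#             self._send_to_priority_tenant_queue()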
@@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
 from core.entities.document_task import DocumentTask
 from core.rag.pipeline.queue import TenantIsolatedTaskQueue
 from enums.cloud_plan import CloudPlan
-from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
+from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy


 class DocumentIndexingTaskProxyTestDataFactory:
@@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy:
         assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
         assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_features_property(self, mock_feature_service):
         """Test cached_property features."""
         # Arrange
@@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy:
         assert features1 is features2  # Should be the same instance due to caching
         mock_feature_service.get_features.assert_called_once_with("tenant-123")

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_direct_queue(self, mock_task):
         """Test _send_to_direct_queue method."""
         # Arrange
@@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy:
             tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
         )

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
         """Test _send_to_tenant_queue when task key exists."""
         # Arrange
@@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy:
         assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
         mock_task.delay.assert_not_called()

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
+    @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
     def test_send_to_tenant_queue_without_task_key(self, mock_task):
         """Test _send_to_tenant_queue when no task key exists."""
         # Arrange
@@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy:
         )
         proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()

-    @patch("services.document_indexing_task_proxy.normal_document_indexing_task")
-    def test_send_to_default_tenant_queue(self, mock_task):
+    def test_send_to_default_tenant_queue(self):
         """Test _send_to_default_tenant_queue method."""
         # Arrange
         proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy:
         proxy._send_to_default_tenant_queue()

         # Assert
-        proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
+        proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)

-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
-    def test_send_to_priority_tenant_queue(self, mock_task):
+    def test_send_to_priority_tenant_queue(self):
         """Test _send_to_priority_tenant_queue method."""
         # Arrange
         proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy:
         proxy._send_to_priority_tenant_queue()

         # Assert
-        proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
+        proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)

-    @patch("services.document_indexing_task_proxy.priority_document_indexing_task")
-    def test_send_to_priority_direct_queue(self, mock_task):
+    def test_send_to_priority_direct_queue(self):
         """Test _send_to_priority_direct_queue method."""
         # Arrange
         proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy:
         proxy._send_to_priority_direct_queue()

         # Assert
-        proxy._send_to_direct_queue.assert_called_once_with(mock_task)
+        proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
         """Test _dispatch method when billing is enabled with sandbox plan."""
         # Arrange
@@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_default_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
         """Test _dispatch method when billing is enabled with non-sandbox plan."""
         # Arrange
@@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy:
         # If billing is enabled with a non-sandbox plan, should send to the priority tenant queue
         proxy._send_to_priority_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_dispatch_with_billing_disabled(self, mock_feature_service):
         """Test _dispatch method when billing is disabled."""
         # Arrange
@@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy:
         # If billing is disabled (for example: self-hosted or enterprise), should send to the priority direct queue
         proxy._send_to_priority_direct_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_delay_method(self, mock_feature_service):
         """Test delay method integration."""
         # Arrange
@@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy:
         assert task.dataset_id == dataset_id
         assert task.document_ids == document_ids

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
         """Test _dispatch method with empty plan string."""
         # Arrange
@@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy:
         # Assert
         proxy._send_to_priority_tenant_queue.assert_called_once()

-    @patch("services.document_indexing_task_proxy.FeatureService")
+    @patch("services.document_indexing_proxy.base.FeatureService")
     def test_dispatch_edge_case_none_plan(self, mock_feature_service):
         """Test _dispatch method with None plan."""
         # Arrange
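# The serialization test above checks that whatever is pushed round-trips
# through DocumentTask(**payload). A sketch of the push side, assuming
# DocumentTask is a small pydantic-style model with exactly these three fields
# (model_dump is an assumption; a dataclasses.asdict would serve the same role):
#
#     task = DocumentTask(
#         tenant_id=self._tenant_id,
#         dataset_id=self._dataset_id,
#         document_ids=self._document_ids,
#     )
#     self._tenant_isolated_task_queue.push_tasks([task.model_dump()])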
@@ -0,0 +1,363 @@
from unittest.mock import Mock, patch

from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import (
    DuplicateDocumentIndexingTaskProxy,
)


class DuplicateDocumentIndexingTaskProxyTestDataFactory:
    """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""

    @staticmethod
    def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
        """Create mock features with billing configuration."""
        features = Mock()
        features.billing = Mock()
        features.billing.enabled = billing_enabled
        features.billing.subscription = Mock()
        features.billing.subscription.plan = plan
        return features

    @staticmethod
    def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
        """Create mock TenantIsolatedTaskQueue."""
        queue = Mock(spec=TenantIsolatedTaskQueue)
        queue.get_task_key.return_value = "task_key" if has_task_key else None
        queue.push_tasks = Mock()
        queue.set_task_waiting_time = Mock()
        return queue

    @staticmethod
    def create_duplicate_document_task_proxy(
        tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
    ) -> DuplicateDocumentIndexingTaskProxy:
        """Create DuplicateDocumentIndexingTaskProxy instance for testing."""
        if document_ids is None:
            document_ids = ["doc-1", "doc-2", "doc-3"]
        return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)


class TestDuplicateDocumentIndexingTaskProxy:
    """Test cases for DuplicateDocumentIndexingTaskProxy class."""

    def test_initialization(self):
        """Test DuplicateDocumentIndexingTaskProxy initialization."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1", "doc-2", "doc-3"]

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids
        assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
        assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
        assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing"

    def test_queue_name(self):
        """Test QUEUE_NAME class variable."""
        # Arrange & Act
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()

        # Assert
        assert proxy.QUEUE_NAME == "duplicate_document_indexing"

    def test_task_functions(self):
        """Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables."""
        # Arrange & Act
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()

        # Assert
        assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task"
        assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task"

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_features_property(self, mock_feature_service):
        """Test cached_property features."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features()
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()

        # Act
        features1 = proxy.features
        features2 = proxy.features  # Second call should use cached property

        # Assert
        assert features1 == mock_features
        assert features2 == mock_features
        assert features1 is features2  # Should be the same instance due to caching
        mock_feature_service.get_features.assert_called_once_with("tenant-123")

    @patch(
        "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
    )
    def test_send_to_direct_queue(self, mock_task):
        """Test _send_to_direct_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        mock_task.delay = Mock()

        # Act
        proxy._send_to_direct_queue(mock_task)

        # Assert
        mock_task.delay.assert_called_once_with(
            tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
        )

    @patch(
        "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
    )
    def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
        """Test _send_to_tenant_queue when task key exists."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=True
        )
        mock_task.delay = Mock()

        # Act
        proxy._send_to_tenant_queue(mock_task)

        # Assert
        proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
        pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
        assert len(pushed_tasks) == 1
        assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
        assert pushed_tasks[0]["tenant_id"] == "tenant-123"
        assert pushed_tasks[0]["dataset_id"] == "dataset-456"
        assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
        mock_task.delay.assert_not_called()

    @patch(
        "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
    )
    def test_send_to_tenant_queue_without_task_key(self, mock_task):
        """Test _send_to_tenant_queue when no task key exists."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
            has_task_key=False
        )
        mock_task.delay = Mock()

        # Act
        proxy._send_to_tenant_queue(mock_task)

        # Assert
        proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
        mock_task.delay.assert_called_once_with(
            tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
        )
        proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()

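    # The two tests above pin the task-key gate. A sketch of that gating logic,
    # assuming only the queue API mocked by the factory (get_task_key,
    # push_tasks, set_task_waiting_time); the real method lives in the proxy base:
    #
    #     def _send_to_tenant_queue(self, task_func):
    #         if self._tenant_isolated_task_queue.get_task_key():
    #             # A task is already in flight for this tenant: queue the payload.
    #             self._tenant_isolated_task_queue.push_tasks([...])
    #         else:
    #             # First task for this tenant: mark it waiting and dispatch immediately.
    #             self._tenant_isolated_task_queue.set_task_waiting_time()
    #             task_func.delay(...)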
    def test_send_to_default_tenant_queue(self):
        """Test _send_to_default_tenant_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_tenant_queue = Mock()

        # Act
        proxy._send_to_default_tenant_queue()

        # Assert
        proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)

    def test_send_to_priority_tenant_queue(self):
        """Test _send_to_priority_tenant_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_tenant_queue = Mock()

        # Act
        proxy._send_to_priority_tenant_queue()

        # Assert
        proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)

    def test_send_to_priority_direct_queue(self):
        """Test _send_to_priority_direct_queue method."""
        # Arrange
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_direct_queue = Mock()

        # Act
        proxy._send_to_priority_direct_queue()

        # Assert
        proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with sandbox plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.SANDBOX
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_default_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        proxy._send_to_default_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with non-sandbox plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.TEAM
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # If billing is enabled with a non-sandbox plan, should send to the priority tenant queue
        proxy._send_to_priority_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_billing_disabled(self, mock_feature_service):
        """Test _dispatch method when billing is disabled."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_direct_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        # If billing is disabled (for example: self-hosted or enterprise), should send to the priority direct queue
        proxy._send_to_priority_direct_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_delay_method(self, mock_feature_service):
        """Test delay method integration."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.SANDBOX
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_default_tenant_queue = Mock()

        # Act
        proxy.delay()

        # Assert
        # If billing is enabled with the sandbox plan, should send to the default tenant queue
        proxy._send_to_default_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
        """Test _dispatch method with empty plan string."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=""
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        proxy._send_to_priority_tenant_queue.assert_called_once()

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_edge_case_none_plan(self, mock_feature_service):
        """Test _dispatch method with None plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=None
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        proxy._send_to_priority_tenant_queue.assert_called_once()

    def test_initialization_with_empty_document_ids(self):
        """Test initialization with empty document_ids list."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = []

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids

    def test_initialization_with_single_document_id(self):
        """Test initialization with single document_id."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = ["doc-1"]

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids

    def test_initialization_with_large_batch(self):
        """Test initialization with large batch of document IDs."""
        # Arrange
        tenant_id = "tenant-123"
        dataset_id = "dataset-456"
        document_ids = [f"doc-{i}" for i in range(100)]

        # Act
        proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

        # Assert
        assert proxy._tenant_id == tenant_id
        assert proxy._dataset_id == dataset_id
        assert proxy._document_ids == document_ids
        assert len(proxy._document_ids) == 100

    @patch("services.document_indexing_proxy.base.FeatureService")
    def test_dispatch_with_professional_plan(self, mock_feature_service):
        """Test _dispatch method when billing is enabled with professional plan."""
        # Arrange
        mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
            billing_enabled=True, plan=CloudPlan.PROFESSIONAL
        )
        mock_feature_service.get_features.return_value = mock_features
        proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
        proxy._send_to_priority_tenant_queue = Mock()

        # Act
        proxy._dispatch()

        # Assert
        proxy._send_to_priority_tenant_queue.assert_called_once()
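# Given the ClassVar contract declared on the base proxy, the class under test
# presumably reduces to a thin declaration. A sketch, with the parent class name
# taken from the package's exports and the wiring assumed rather than shown here:
#
#     class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
#         QUEUE_NAME = "duplicate_document_indexing"
#         NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task
#         PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task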
@@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
 from enums.cloud_plan import CloudPlan
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Document
-from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
+from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
 from tasks.document_indexing_task import (
     _document_indexing,
     _document_indexing_with_tenant_queue,
@@ -138,7 +138,9 @@ class TestTaskEnqueuing:
         with patch.object(DocumentIndexingTaskProxy, "features") as mock_features:
             mock_features.billing.enabled = False

-            with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
+            # Mock the class variable directly
+            mock_task = Mock()
+            with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
                 proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

                 # Act
@@ -163,7 +165,9 @@ class TestTaskEnqueuing:
             mock_features.billing.enabled = True
             mock_features.billing.subscription.plan = CloudPlan.SANDBOX

-            with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task:
+            # Mock the class variable directly
+            mock_task = Mock()
+            with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task):
                 proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

                 # Act
@@ -187,7 +191,9 @@ class TestTaskEnqueuing:
             mock_features.billing.enabled = True
             mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL

-            with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
+            # Mock the class variable directly
+            mock_task = Mock()
+            with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
                 proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

                 # Act
@@ -211,7 +217,9 @@ class TestTaskEnqueuing:
             mock_features.billing.enabled = True
             mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL

-            with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
+            # Mock the class variable directly
+            mock_task = Mock()
+            with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
                 proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)

                 # Act
@@ -1493,7 +1501,9 @@ class TestEdgeCases:
         mock_features.billing.enabled = True
         mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL

-        with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
+        # Mock the class variable directly
+        mock_task = Mock()
+        with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
            # Act - Enqueue multiple tasks rapidly
            for doc_ids in document_ids_list:
                proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids)
@@ -1898,7 +1908,7 @@ class TestRobustness:
         - Error is propagated appropriately
         """
         # Arrange
-        with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features:
+        with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features:
             # Simulate FeatureService failure
             mock_get_features.side_effect = Exception("Feature service unavailable")
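# The hunks above swap patch("module.task_func") for patch.object(Proxy, "TASK_FUNC", ...):
# once the task function is captured in a ClassVar at class-definition time,
# re-binding the module attribute no longer intercepts calls made through the class.
# A minimal illustration of the working pattern (names reused from the tests;
# the dispatch path through delay() is assumed from the assertions above):
#
#     mock_task = Mock()
#     with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
#         DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids).delay()
#     mock_task.delay.assert_called_once()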
@@ -0,0 +1,567 @@
"""
Unit tests for duplicate document indexing tasks.

This module tests the duplicate document indexing task functionality including:
- Task enqueuing to different queues (normal, priority, tenant-isolated)
- Batch processing of multiple duplicate documents
- Progress tracking through task lifecycle
- Error handling and retry mechanisms
- Cleanup of old document data before re-indexing
"""

import uuid
from unittest.mock import MagicMock, Mock, patch

import pytest

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
    _duplicate_document_indexing_task,
    _duplicate_document_indexing_task_with_tenant_queue,
    duplicate_document_indexing_task,
    normal_duplicate_document_indexing_task,
    priority_duplicate_document_indexing_task,
)

# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def tenant_id():
    """Generate a unique tenant ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def dataset_id():
    """Generate a unique dataset ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def document_ids():
    """Generate a list of document IDs for testing."""
    return [str(uuid.uuid4()) for _ in range(3)]


@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
    """Create a mock Dataset object."""
    dataset = Mock(spec=Dataset)
    dataset.id = dataset_id
    dataset.tenant_id = tenant_id
    dataset.indexing_technique = "high_quality"
    dataset.embedding_model_provider = "openai"
    dataset.embedding_model = "text-embedding-ada-002"
    return dataset


@pytest.fixture
def mock_documents(document_ids, dataset_id):
    """Create mock Document objects."""
    documents = []
    for doc_id in document_ids:
        doc = Mock(spec=Document)
        doc.id = doc_id
        doc.dataset_id = dataset_id
        doc.indexing_status = "waiting"
        doc.error = None
        doc.stopped_at = None
        doc.processing_started_at = None
        doc.doc_form = "text_model"
        documents.append(doc)
    return documents


@pytest.fixture
def mock_document_segments(document_ids):
    """Create mock DocumentSegment objects."""
    segments = []
    for doc_id in document_ids:
        for i in range(3):
            segment = Mock(spec=DocumentSegment)
            segment.id = str(uuid.uuid4())
            segment.document_id = doc_id
            segment.index_node_id = f"node-{doc_id}-{i}"
            segments.append(segment)
    return segments


@pytest.fixture
def mock_db_session():
    """Mock database session."""
    with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_session.scalars.return_value = MagicMock()
        yield mock_session


@pytest.fixture
def mock_indexing_runner():
    """Mock IndexingRunner."""
    with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class:
        mock_runner = MagicMock(spec=IndexingRunner)
        mock_runner_class.return_value = mock_runner
        yield mock_runner


@pytest.fixture
def mock_feature_service():
    """Mock FeatureService."""
    with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service:
        mock_features = Mock()
        mock_features.billing = Mock()
        mock_features.billing.enabled = False
        mock_features.vector_space = Mock()
        mock_features.vector_space.size = 0
        mock_features.vector_space.limit = 1000
        mock_service.get_features.return_value = mock_features
        yield mock_service


@pytest.fixture
def mock_index_processor_factory():
    """Mock IndexProcessorFactory."""
    with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_processor.clean = Mock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor
        yield mock_factory


@pytest.fixture
def mock_tenant_isolated_queue():
    """Mock TenantIsolatedTaskQueue."""
    with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class:
        mock_queue = MagicMock(spec=TenantIsolatedTaskQueue)
        mock_queue.pull_tasks.return_value = []
        mock_queue.delete_task_key = Mock()
        mock_queue.set_task_waiting_time = Mock()
        mock_queue_class.return_value = mock_queue
        yield mock_queue
# ============================================================================
# Tests for deprecated duplicate_document_indexing_task
# ============================================================================


class TestDuplicateDocumentIndexingTask:
    """Tests for the deprecated duplicate_document_indexing_task function."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids):
        """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function."""
        # Act
        duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        mock_core_func.assert_called_once_with(dataset_id, document_ids)

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id):
        """Test duplicate_document_indexing_task with empty document_ids list."""
        # Arrange
        document_ids = []

        # Act
        duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        mock_core_func.assert_called_once_with(dataset_id, document_ids)


# ============================================================================
# Tests for _duplicate_document_indexing_task core function
# ============================================================================


class TestDuplicateDocumentIndexingTaskCore:
    """Tests for the _duplicate_document_indexing_task core function."""

    def test_successful_duplicate_document_indexing(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        mock_document_segments,
        dataset_id,
        document_ids,
    ):
        """Test successful duplicate document indexing flow."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Verify IndexingRunner was called
        mock_indexing_runner.run.assert_called_once()

        # Verify all documents were set to parsing status
        for doc in mock_documents:
            assert doc.indexing_status == "parsing"
            assert doc.processing_started_at is not None

        # Verify session operations
        assert mock_db_session.commit.called
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids):
        """Test duplicate document indexing when dataset is not found."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = None

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should close the session at least once
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
        self,
        mock_db_session,
        mock_feature_service,
        mock_dataset,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing with billing enabled and sandbox plan."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
        mock_features = mock_feature_service.get_features.return_value
        mock_features.billing.enabled = True
        mock_features.billing.subscription.plan = CloudPlan.SANDBOX

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # For sandbox plan with multiple documents, should fail
        mock_db_session.commit.assert_called()

    def test_duplicate_document_indexing_with_billing_limit_exceeded(
        self,
        mock_db_session,
        mock_feature_service,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when billing limit is exceeded."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []  # No segments to clean
        mock_features = mock_feature_service.get_features.return_value
        mock_features.billing.enabled = True
        mock_features.billing.subscription.plan = CloudPlan.TEAM
        mock_features.vector_space.size = 990
        mock_features.vector_space.limit = 1000

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should commit the session
        assert mock_db_session.commit.called
        # Should close the session
        assert mock_db_session.close.called

    def test_duplicate_document_indexing_runner_error(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when IndexingRunner raises an error."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []
        mock_indexing_runner.run.side_effect = Exception("Indexing error")

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should close the session even after error
        mock_db_session.close.assert_called_once()

    def test_duplicate_document_indexing_document_is_paused(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        dataset_id,
        document_ids,
    ):
        """Test duplicate document indexing when document is paused."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = []
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Should handle DocumentIsPausedError gracefully
        mock_db_session.close.assert_called_once()

    def test_duplicate_document_indexing_cleans_old_segments(
        self,
        mock_db_session,
        mock_indexing_runner,
        mock_feature_service,
        mock_index_processor_factory,
        mock_dataset,
        mock_documents,
        mock_document_segments,
        dataset_id,
        document_ids,
    ):
        """Test that duplicate document indexing cleans old segments."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value

        # Act
        _duplicate_document_indexing_task(dataset_id, document_ids)

        # Assert
        # Verify clean was called for each document
        assert mock_processor.clean.call_count == len(mock_documents)

        # Verify segments were deleted
        for segment in mock_document_segments:
            mock_db_session.delete.assert_any_call(segment)


# ============================================================================
# Tests for tenant queue wrapper function
# ============================================================================


class TestDuplicateDocumentIndexingTaskWithTenantQueue:
    """Tests for _duplicate_document_indexing_task_with_tenant_queue function."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_calls_core_function(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper calls the core function."""
        # Arrange
        mock_task_func = Mock()

        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)

        # Assert
        mock_core_func.assert_called_once_with(dataset_id, document_ids)

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_deletes_key_when_no_tasks(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper deletes task key when no more tasks."""
        # Arrange
        mock_task_func = Mock()
        mock_tenant_isolated_queue.pull_tasks.return_value = []

        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)

        # Assert
        mock_tenant_isolated_queue.delete_task_key.assert_called_once()

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_processes_next_tasks(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper processes next tasks from queue."""
        # Arrange
        mock_task_func = Mock()
        next_task = {
            "tenant_id": tenant_id,
            "dataset_id": dataset_id,
            "document_ids": document_ids,
        }
        mock_tenant_isolated_queue.pull_tasks.return_value = [next_task]

        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)

        # Assert
        mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once()
        mock_task_func.delay.assert_called_once_with(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            document_ids=document_ids,
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
    def test_tenant_queue_wrapper_handles_core_function_error(
        self,
        mock_core_func,
        mock_tenant_isolated_queue,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that tenant queue wrapper handles errors from core function."""
        # Arrange
        mock_task_func = Mock()
        mock_core_func.side_effect = Exception("Core function error")

        # Act
        _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)

        # Assert
        # Should still check for next tasks even after error
        mock_tenant_isolated_queue.pull_tasks.assert_called_once()
# ============================================================================
# Tests for normal_duplicate_document_indexing_task
# ============================================================================

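# NOTE: sketch only (inferred from the assertions below, not from the task
# source): the entry point is expected to be a thin Celery task that forwards
# to the tenant-queue wrapper and passes itself as task_func, so re-queued
# work stays on the same priority lane, e.g.:
#
#     @shared_task(queue="dataset")  # hypothetical decorator and queue name
#     def normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids):
#         _duplicate_document_indexing_task_with_tenant_queue(
#             tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
#         )
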
class TestNormalDuplicateDocumentIndexingTask:
    """Tests for the normal_duplicate_document_indexing_task function."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_normal_task_calls_tenant_queue_wrapper(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that the normal task calls the tenant queue wrapper."""
        # Act
        normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_normal_task_with_empty_document_ids(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Test the normal task with an empty document_ids list."""
        # Arrange
        document_ids = []

        # Act
        normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
        )

# ============================================================================
# Tests for priority_duplicate_document_indexing_task
# ============================================================================

class TestPriorityDuplicateDocumentIndexingTask:
    """Tests for the priority_duplicate_document_indexing_task function."""

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_calls_tenant_queue_wrapper(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
        document_ids,
    ):
        """Test that the priority task calls the tenant queue wrapper."""
        # Act
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_with_single_document(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Test the priority task with a single document."""
        # Arrange
        document_ids = ["doc-1"]

        # Act
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )

    @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
    def test_priority_task_with_large_batch(
        self,
        mock_wrapper_func,
        tenant_id,
        dataset_id,
    ):
        """Test the priority task with a large batch of documents."""
        # Arrange
        document_ids = [f"doc-{i}" for i in range(100)]

        # Act
        priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids)

        # Assert
        mock_wrapper_func.assert_called_once_with(
            tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
        )
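

# ----------------------------------------------------------------------------
# Illustrative, self-contained sketch (an assumption for exposition, not part
# of the code under test): the continuation pattern the wrapper tests above
# exercise, with a plain in-memory deque standing in for the Redis-backed
# tenant-isolated queue and a direct call standing in for task_func.delay().
# ----------------------------------------------------------------------------
def _demo_tenant_queue_continuation():
    from collections import deque

    pending = deque(
        [
            {"tenant_id": "t-1", "dataset_id": "ds-2", "document_ids": ["doc-2"]},
            {"tenant_id": "t-1", "dataset_id": "ds-3", "document_ids": ["doc-3"]},
        ]
    )
    processed = []

    def run(task):
        try:
            # Core indexing work; even if it raises, the continuation runs.
            processed.append(task["dataset_id"])
        finally:
            if pending:
                # Re-dispatch the tenant's next task (task_func.delay in Celery).
                run(pending.popleft())
            # else: the real wrapper deletes the tenant's task key here.

    run({"tenant_id": "t-1", "dataset_id": "ds-1", "document_ids": ["doc-1"]})
    assert processed == ["ds-1", "ds-2", "ds-3"]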