mirror of
https://github.com/langgenius/dify.git
synced 2026-03-20 10:00:02 +08:00
Merge branch 'main' into feat/memory-orchestration-fed
This commit is contained in:
commit
b7311bf1d3
@ -615,5 +615,8 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
@ -1142,6 +1142,13 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1166,6 +1173,7 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
|
||||
@ -40,20 +40,15 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -249,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
if rag_pipeline_invoke_entities:
|
||||
# store the rag_pipeline_invoke_entities to object storage
|
||||
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
|
||||
name = "rag_pipeline_invoke_entities.json"
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
|
||||
|
||||
if redis_client.get(tenant_pipeline_task_key):
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
else:
|
||||
priority_rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
|
||||
# return batch, dataset, documents
|
||||
return {
|
||||
"batch": batch,
|
||||
|
||||
15
api/core/entities/document_task.py
Normal file
15
api/core/entities/document_task.py
Normal file
@ -0,0 +1,15 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentTask:
|
||||
"""Document task entity for document indexing operations.
|
||||
|
||||
This class represents a document indexing task that can be queued
|
||||
and processed by the document indexing system.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
document_ids: Sequence[str]
|
||||
@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
|
||||
# Return composite sort key: (model_type value, model position index)
|
||||
return (model.model_type.value, position_index)
|
||||
|
||||
# Deduplicate
|
||||
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
|
||||
|
||||
# Sort using the composite sort key
|
||||
return sorted(provider_models, key=get_sort_key)
|
||||
|
||||
|
||||
@ -92,7 +92,7 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python verions,
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
|
||||
0
api/core/rag/pipeline/__init__.py
Normal file
0
api/core/rag/pipeline/__init__.py
Normal file
79
api/core/rag/pipeline/queue.py
Normal file
79
api/core/rag/pipeline/queue.py
Normal file
@ -0,0 +1,79 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
|
||||
|
||||
|
||||
class TaskWrapper(BaseModel):
|
||||
data: Any
|
||||
|
||||
def serialize(self) -> str:
|
||||
return self.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
|
||||
return cls.model_validate_json(serialized_data)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueue:
|
||||
"""
|
||||
Simple queue for tenant isolated tasks, used for rag related tenant tasks isolation.
|
||||
It uses Redis list to store tasks, and Redis key to store task waiting flag.
|
||||
Support tasks that can be serialized by json.
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, unique_key: str):
|
||||
self._tenant_id = tenant_id
|
||||
self._unique_key = unique_key
|
||||
self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
|
||||
self._task_key = f"tenant_{unique_key}_task:{tenant_id}"
|
||||
|
||||
def get_task_key(self):
|
||||
return redis_client.get(self._task_key)
|
||||
|
||||
def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
|
||||
redis_client.setex(self._task_key, ttl, 1)
|
||||
|
||||
def delete_task_key(self):
|
||||
redis_client.delete(self._task_key)
|
||||
|
||||
def push_tasks(self, tasks: Sequence[Any]):
|
||||
serialized_tasks = []
|
||||
for task in tasks:
|
||||
# Store str list directly, maintaining full compatibility for pipeline scenarios
|
||||
if isinstance(task, str):
|
||||
serialized_tasks.append(task)
|
||||
else:
|
||||
# Use TaskWrapper to do JSON serialization for non-string tasks
|
||||
wrapper = TaskWrapper(data=task)
|
||||
serialized_data = wrapper.serialize()
|
||||
serialized_tasks.append(serialized_data)
|
||||
|
||||
redis_client.lpush(self._queue, *serialized_tasks)
|
||||
|
||||
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
|
||||
if count <= 0:
|
||||
return []
|
||||
|
||||
tasks = []
|
||||
for _ in range(count):
|
||||
serialized_task = redis_client.rpop(self._queue)
|
||||
if not serialized_task:
|
||||
break
|
||||
|
||||
if isinstance(serialized_task, bytes):
|
||||
serialized_task = serialized_task.decode("utf-8")
|
||||
|
||||
try:
|
||||
wrapper = TaskWrapper.deserialize(serialized_task)
|
||||
tasks.append(wrapper.data)
|
||||
except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
|
||||
# Fall back to raw string for legacy format or invalid JSON
|
||||
tasks.append(serialized_task)
|
||||
|
||||
return tasks
|
||||
@ -1,16 +1,19 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
@ -52,6 +55,11 @@ class MCPTool(Tool):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
@ -97,6 +105,10 @@ class MCPTool(Tool):
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
||||
@ -153,7 +153,11 @@ class VariablePool(BaseModel):
|
||||
return None
|
||||
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
segment: Segment | None = self.variable_dictionary[node_id].get(name)
|
||||
node_map = self.variable_dictionary.get(node_id)
|
||||
if node_map is None:
|
||||
return None
|
||||
|
||||
segment: Segment | None = node_map.get(name)
|
||||
|
||||
if segment is None:
|
||||
return None
|
||||
|
||||
@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
|
||||
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
||||
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \
|
||||
-Q ${CELERY_QUEUES:-dataset,priority_dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \
|
||||
--prefetch-multiplier=1
|
||||
|
||||
elif [[ "${MODE}" == "beat" ]]; then
|
||||
|
||||
@ -110,7 +110,7 @@ class Account(UserMixin, TypeBase):
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
role: TenantAccountRole | None = field(default=None, init=False)
|
||||
@ -250,7 +250,9 @@ class Tenant(TypeBase):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def get_accounts(self) -> list[Account]:
|
||||
return list(
|
||||
@ -289,7 +291,7 @@ class TenantAccountJoin(TypeBase):
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
@ -310,7 +312,7 @@ class AccountIntegrate(TypeBase):
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
@ -396,5 +398,5 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@ -61,18 +61,20 @@ class Dataset(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
embedding_model = mapped_column(db.String(255), nullable=True)
|
||||
embedding_model_provider = mapped_column(db.String(255), nullable=True)
|
||||
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
embedding_model = mapped_column(sa.String(255), nullable=True)
|
||||
embedding_model_provider = mapped_column(sa.String(255), nullable=True)
|
||||
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=True)
|
||||
retrieval_model = mapped_column(JSONB, nullable=True)
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
icon_info = mapped_column(JSONB, nullable=True)
|
||||
runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
|
||||
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'::character varying"))
|
||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||
chunk_structure = mapped_column(db.String(255), nullable=True)
|
||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
|
||||
@property
|
||||
def total_documents(self):
|
||||
@ -399,7 +401,9 @@ class Document(Base):
|
||||
archived_reason = mapped_column(String(255), nullable=True)
|
||||
archived_by = mapped_column(StringUUID, nullable=True)
|
||||
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_metadata = mapped_column(JSONB, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
|
||||
@ -716,7 +720,9 @@ class DocumentSegment(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
@ -881,7 +887,7 @@ class ChildChunk(Base):
|
||||
)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
|
||||
)
|
||||
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
@ -1036,8 +1042,8 @@ class TidbAuthBinding(Base):
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'CREATING'::character varying"))
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -1088,7 +1094,9 @@ class ExternalKnowledgeApis(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -1141,7 +1149,9 @@ class ExternalKnowledgeBindings(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class DatasetAutoDisableLog(Base):
|
||||
@ -1197,7 +1207,7 @@ class DatasetMetadata(Base):
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -1224,44 +1234,48 @@ class DatasetMetadataBinding(Base):
|
||||
|
||||
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_built_in_templates"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
chunk_structure = mapped_column(db.String(255), nullable=False)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=False)
|
||||
icon = mapped_column(sa.JSON, nullable=False)
|
||||
yaml_content = mapped_column(sa.Text, nullable=False)
|
||||
copyright = mapped_column(db.String(255), nullable=False)
|
||||
privacy_policy = mapped_column(db.String(255), nullable=False)
|
||||
copyright = mapped_column(sa.String(255), nullable=False)
|
||||
privacy_policy = mapped_column(sa.String(255), nullable=False)
|
||||
position = mapped_column(sa.Integer, nullable=False)
|
||||
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(db.String(255), nullable=False)
|
||||
language = mapped_column(sa.String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_customized_templates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
|
||||
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
|
||||
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
chunk_structure = mapped_column(db.String(255), nullable=False)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=False)
|
||||
icon = mapped_column(sa.JSON, nullable=False)
|
||||
position = mapped_column(sa.Integer, nullable=False)
|
||||
yaml_content = mapped_column(sa.Text, nullable=False)
|
||||
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(db.String(255), nullable=False)
|
||||
language = mapped_column(sa.String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
@ -1273,19 +1287,21 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||
|
||||
class Pipeline(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipelines"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False, server_default=sa.text("''::character varying"))
|
||||
workflow_id = mapped_column(StringUUID, nullable=True)
|
||||
is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
@ -1294,16 +1310,16 @@ class Pipeline(Base): # type: ignore[name-defined]
|
||||
class DocumentPipelineExecutionLog(Base):
|
||||
__tablename__ = "document_pipeline_execution_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
|
||||
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
|
||||
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
pipeline_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
datasource_type = mapped_column(db.String(255), nullable=False)
|
||||
datasource_type = mapped_column(sa.String(255), nullable=False)
|
||||
datasource_info = mapped_column(sa.Text, nullable=False)
|
||||
datasource_node_id = mapped_column(db.String(255), nullable=False)
|
||||
datasource_node_id = mapped_column(sa.String(255), nullable=False)
|
||||
input_data = mapped_column(sa.JSON, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -1311,12 +1327,14 @@ class DocumentPipelineExecutionLog(Base):
|
||||
|
||||
class PipelineRecommendedPlugin(Base):
|
||||
__tablename__ = "pipeline_recommended_plugins"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
plugin_id = mapped_column(sa.Text, nullable=False)
|
||||
provider_name = mapped_column(sa.Text, nullable=False)
|
||||
position = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
active = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@ -95,7 +95,9 @@ class App(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
@ -314,7 +316,9 @@ class AppModelConfig(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
opening_statement = mapped_column(sa.Text)
|
||||
suggested_questions = mapped_column(sa.Text)
|
||||
suggested_questions_after_answer = mapped_column(sa.Text)
|
||||
@ -545,7 +549,9 @@ class RecommendedApp(Base):
|
||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
@ -644,7 +650,9 @@ class Conversation(Base):
|
||||
read_account_id = mapped_column(StringUUID)
|
||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||
message_annotations = db.relationship(
|
||||
@ -948,7 +956,9 @@ class Message(Base):
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
@ -1296,7 +1306,9 @@ class MessageFeedback(Base):
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
@ -1378,7 +1390,9 @@ class MessageAnnotation(Base):
|
||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
@ -1443,7 +1457,9 @@ class AppAnnotationSetting(Base):
|
||||
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def collection_binding_detail(self):
|
||||
@ -1471,7 +1487,9 @@ class OperationLog(Base):
|
||||
content = mapped_column(sa.JSON)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class DefaultEndUserSessionID(StrEnum):
|
||||
@ -1510,7 +1528,9 @@ class EndUser(Base, UserMixin):
|
||||
|
||||
session_id: Mapped[str] = mapped_column()
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class AppMCPServer(Base):
|
||||
@ -1530,7 +1550,9 @@ class AppMCPServer(Base):
|
||||
parameters = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_server_code(n: int) -> str:
|
||||
@ -1576,7 +1598,9 @@ class Site(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
code = mapped_column(String(255))
|
||||
|
||||
@property
|
||||
|
||||
@ -1,62 +1,66 @@
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "datasource_oauth_params"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
||||
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
|
||||
|
||||
class DatasourceProvider(Base):
|
||||
__tablename__ = "datasource_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
||||
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
||||
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default")
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class DatasourceOauthTenantParamConfig(Base):
|
||||
__tablename__ = "datasource_oauth_tenant_params"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={})
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@ -72,7 +72,9 @@ class Provider(Base):
|
||||
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@ -135,7 +137,9 @@ class ProviderModel(Base):
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def credential(self):
|
||||
@ -170,7 +174,9 @@ class TenantDefaultModel(Base):
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(Base):
|
||||
@ -185,7 +191,9 @@ class TenantPreferredModelProvider(Base):
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderOrder(Base):
|
||||
@ -212,7 +220,9 @@ class ProviderOrder(Base):
|
||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderModelSetting(Base):
|
||||
@ -234,7 +244,9 @@ class ProviderModelSetting(Base):
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(Base):
|
||||
@ -259,7 +271,9 @@ class LoadBalancingModelConfig(Base):
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderCredential(Base):
|
||||
@ -279,7 +293,9 @@ class ProviderCredential(Base):
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderModelCredential(Base):
|
||||
@ -307,4 +323,6 @@ class ProviderModelCredential(Base):
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@ -140,8 +140,8 @@ class Workflow(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=naive_utc_now(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
_environment_variables: Mapped[str] = mapped_column(
|
||||
"environment_variables", sa.Text, nullable=False, server_default="{}"
|
||||
@ -150,7 +150,7 @@ class Workflow(Base):
|
||||
"conversation_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_rag_pipeline_variables: Mapped[str] = mapped_column(
|
||||
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
|
||||
"rag_pipeline_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
@ -50,6 +50,7 @@ from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
from models.source import DataSourceOauthBinding
|
||||
from models.workflow import Workflow
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
ChildChunkUpdateArgs,
|
||||
KnowledgeConfig,
|
||||
@ -79,7 +80,6 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
||||
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
|
||||
@ -1694,7 +1694,7 @@ class DocumentService:
|
||||
|
||||
# trigger async task
|
||||
if document_ids:
|
||||
document_indexing_task.delay(dataset.id, document_ids)
|
||||
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
|
||||
if duplicate_document_ids:
|
||||
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
|
||||
|
||||
|
||||
83
api/services/document_indexing_task_proxy.py
Normal file
83
api/services/document_indexing_task_proxy.py
Normal file
@ -0,0 +1,83 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import asdict
|
||||
from functools import cached_property
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentIndexingTaskProxy:
|
||||
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
self._tenant_id = tenant_id
|
||||
self._dataset_id = dataset_id
|
||||
self._document_ids = document_ids
|
||||
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
@cached_property
|
||||
def features(self):
|
||||
return FeatureService.get_features(self._tenant_id)
|
||||
|
||||
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
|
||||
logger.info("send dataset %s to direct queue", self._dataset_id)
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
||||
)
|
||||
|
||||
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
|
||||
logger.info("send dataset %s to tenant queue", self._dataset_id)
|
||||
if self._tenant_isolated_task_queue.get_task_key():
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
self._tenant_isolated_task_queue.push_tasks(
|
||||
[
|
||||
asdict(
|
||||
DocumentTask(
|
||||
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
self._tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
||||
)
|
||||
logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
|
||||
|
||||
def _send_to_default_tenant_queue(self):
|
||||
self._send_to_tenant_queue(normal_document_indexing_task)
|
||||
|
||||
def _send_to_priority_tenant_queue(self):
|
||||
self._send_to_tenant_queue(priority_document_indexing_task)
|
||||
|
||||
def _send_to_priority_direct_queue(self):
|
||||
self._send_to_direct_queue(priority_document_indexing_task)
|
||||
|
||||
def _dispatch(self):
|
||||
logger.info(
|
||||
"dispatch args: %s - %s - %s",
|
||||
self._tenant_id,
|
||||
self.features.billing.enabled,
|
||||
self.features.billing.subscription.plan,
|
||||
)
|
||||
# dispatch to different indexing queue with tenant isolation when billing enabled
|
||||
if self.features.billing.enabled:
|
||||
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
|
||||
self._send_to_default_tenant_queue()
|
||||
else:
|
||||
# dispatch to priority pipeline queue with tenant self sub queue for other plans
|
||||
self._send_to_priority_tenant_queue()
|
||||
else:
|
||||
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
|
||||
self._send_to_priority_direct_queue()
|
||||
|
||||
def delay(self):
|
||||
self._dispatch()
|
||||
106
api/services/rag_pipeline/rag_pipeline_task_proxy.py
Normal file
106
api/services/rag_pipeline/rag_pipeline_task_proxy.py
Normal file
@ -0,0 +1,106 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import cached_property
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPipelineTaskProxy:
|
||||
# Default uploaded file name for rag pipeline invoke entities
|
||||
_RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"
|
||||
|
||||
def __init__(
|
||||
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
|
||||
):
|
||||
self._dataset_tenant_id = dataset_tenant_id
|
||||
self._user_id = user_id
|
||||
self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
|
||||
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")
|
||||
|
||||
@cached_property
|
||||
def features(self):
|
||||
return FeatureService.get_features(self._dataset_tenant_id)
|
||||
|
||||
def _upload_invoke_entities(self) -> str:
|
||||
text = [item.model_dump() for item in self._rag_pipeline_invoke_entities]
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
|
||||
)
|
||||
return upload_file.id
|
||||
|
||||
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||
logger.info("send file %s to direct queue", upload_file_id)
|
||||
task_func.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||
tenant_id=self._dataset_tenant_id,
|
||||
)
|
||||
|
||||
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||
logger.info("send file %s to tenant queue", upload_file_id)
|
||||
if self._tenant_isolated_task_queue.get_task_key():
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
|
||||
logger.info("push tasks: %s", upload_file_id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
self._tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||
tenant_id=self._dataset_tenant_id,
|
||||
)
|
||||
logger.info("init tasks: %s", upload_file_id)
|
||||
|
||||
def _send_to_default_tenant_queue(self, upload_file_id: str):
|
||||
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
|
||||
|
||||
def _send_to_priority_tenant_queue(self, upload_file_id: str):
|
||||
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
|
||||
|
||||
def _send_to_priority_direct_queue(self, upload_file_id: str):
|
||||
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
|
||||
|
||||
def _dispatch(self):
|
||||
upload_file_id = self._upload_invoke_entities()
|
||||
if not upload_file_id:
|
||||
raise ValueError("upload_file_id is empty")
|
||||
|
||||
logger.info(
|
||||
"dispatch args: %s - %s - %s",
|
||||
self._dataset_tenant_id,
|
||||
self.features.billing.enabled,
|
||||
self.features.billing.subscription.plan,
|
||||
)
|
||||
|
||||
# dispatch to different pipeline queue with tenant isolation when billing enabled
|
||||
if self.features.billing.enabled:
|
||||
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
|
||||
self._send_to_default_tenant_queue(upload_file_id)
|
||||
else:
|
||||
# dispatch to priority pipeline queue with tenant isolation for other plans
|
||||
self._send_to_priority_tenant_queue(upload_file_id)
|
||||
else:
|
||||
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
|
||||
self._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
def delay(self):
|
||||
if not self._rag_pipeline_invoke_entities:
|
||||
logger.warning(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
|
||||
self._dataset_tenant_id,
|
||||
self._user_id,
|
||||
)
|
||||
return
|
||||
self._dispatch()
|
||||
@ -126,7 +126,7 @@ workflow:
|
||||
type: mixed
|
||||
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
|
||||
plugin_id: langgenius/jina_datasource
|
||||
provider_name: jina
|
||||
provider_name: jinareader
|
||||
provider_type: website_crawl
|
||||
selected: false
|
||||
title: Jina Reader
|
||||
|
||||
@ -126,7 +126,7 @@ workflow:
|
||||
type: mixed
|
||||
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
|
||||
plugin_id: langgenius/jina_datasource
|
||||
provider_name: jina
|
||||
provider_name: jinareader
|
||||
provider_type: website_crawl
|
||||
selected: false
|
||||
title: Jina Reader
|
||||
|
||||
@ -419,7 +419,7 @@ workflow:
|
||||
type: mixed
|
||||
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
|
||||
plugin_id: langgenius/jina_datasource
|
||||
provider_name: jina
|
||||
provider_name: jinareader
|
||||
provider_type: website_crawl
|
||||
selected: false
|
||||
title: Jina Reader
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -22,8 +25,24 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
.. warning:: TO BE DEPRECATED
|
||||
This function will be deprecated and removed in a future version.
|
||||
Use normal_document_indexing_task or priority_document_indexing_task instead.
|
||||
|
||||
Usage: document_indexing_task.delay(dataset_id, document_ids)
|
||||
"""
|
||||
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
|
||||
|
||||
def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Process document for tasks
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: _document_indexing(dataset_id, document_ids)
|
||||
"""
|
||||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
@ -87,3 +106,63 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
|
||||
finally:
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
|
||||
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Async process document
|
||||
:param tenant_id:
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||
"""
|
||||
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
|
||||
|
||||
|
||||
@shared_task(queue="priority_dataset")
|
||||
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Priority async process document
|
||||
:param tenant_id:
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||
"""
|
||||
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
|
||||
|
||||
@ -12,8 +12,10 @@ from celery import shared_task # type: ignore
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant
|
||||
@ -22,6 +24,8 @@ from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.file_service import FileService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="priority_pipeline")
|
||||
def priority_rag_pipeline_run_task(
|
||||
@ -69,6 +73,27 @@ def priority_rag_pipeline_run_task(
|
||||
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
|
||||
raise
|
||||
finally:
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
|
||||
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
priority_rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
file_service = FileService(db.engine)
|
||||
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
|
||||
db.session.close()
|
||||
|
||||
@ -12,17 +12,20 @@ from celery import shared_task # type: ignore
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.file_service import FileService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
@ -70,26 +73,27 @@ def rag_pipeline_run_task(
|
||||
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
|
||||
raise
|
||||
finally:
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
|
||||
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
|
||||
|
||||
if next_file_id:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
redis_client.delete(tenant_pipeline_task_key)
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
file_service = FileService(db.engine)
|
||||
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
|
||||
db.session.close()
|
||||
|
||||
@ -0,0 +1 @@
|
||||
|
||||
@ -0,0 +1,595 @@
|
||||
"""
|
||||
Integration tests for TenantIsolatedTaskQueue using testcontainers.
|
||||
|
||||
These tests verify the Redis-based task queue functionality with real Redis instances,
|
||||
testing tenant isolation, task serialization, and queue operations in a realistic environment.
|
||||
Includes compatibility tests for migrating from legacy string-only queues.
|
||||
|
||||
All tests use generic naming to avoid coupling to specific business implementations.
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestTask:
|
||||
"""Test task data structure for testing complex object serialization."""
|
||||
|
||||
task_id: str
|
||||
tenant_id: str
|
||||
data: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueueIntegration:
|
||||
"""Integration tests for TenantIsolatedTaskQueue using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def fake(self):
|
||||
"""Faker instance for generating test data."""
|
||||
return Faker()
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_and_account(self, db_session_with_containers, fake):
|
||||
"""Create test tenant and account for testing."""
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tenant, account
|
||||
|
||||
@pytest.fixture
|
||||
def test_queue(self, test_tenant_and_account):
|
||||
"""Create a generic test queue for testing."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
|
||||
|
||||
@pytest.fixture
|
||||
def secondary_queue(self, test_tenant_and_account):
|
||||
"""Create a secondary test queue for testing isolation."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
|
||||
|
||||
def test_queue_initialization(self, test_tenant_and_account):
|
||||
"""Test queue initialization with correct key generation."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
|
||||
|
||||
assert queue._tenant_id == tenant.id
|
||||
assert queue._unique_key == "test-key"
|
||||
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
|
||||
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
|
||||
|
||||
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
|
||||
"""Test that different tenants have isolated queues."""
|
||||
tenant1, _ = test_tenant_and_account
|
||||
|
||||
# Create second tenant
|
||||
tenant2 = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant2)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1.id, "same-key")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2.id, "same-key")
|
||||
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
|
||||
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
|
||||
|
||||
def test_key_isolation(self, test_tenant_and_account):
|
||||
"""Test that different keys have isolated queues."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant.id, "key2")
|
||||
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
assert queue1._queue == f"tenant_self_key1_task_queue:{tenant.id}"
|
||||
assert queue2._queue == f"tenant_self_key2_task_queue:{tenant.id}"
|
||||
|
||||
def test_task_key_operations(self, test_queue):
|
||||
"""Test task key operations (get, set, delete)."""
|
||||
# Initially no task key should exist
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
# Set task waiting time with default TTL
|
||||
test_queue.set_task_waiting_time()
|
||||
task_key = test_queue.get_task_key()
|
||||
# Redis returns bytes, convert to string for comparison
|
||||
assert task_key in (b"1", "1")
|
||||
|
||||
# Set task waiting time with custom TTL
|
||||
custom_ttl = 30
|
||||
test_queue.set_task_waiting_time(custom_ttl)
|
||||
task_key = test_queue.get_task_key()
|
||||
assert task_key in (b"1", "1")
|
||||
|
||||
# Delete task key
|
||||
test_queue.delete_task_key()
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
def test_push_and_pull_string_tasks(self, test_queue):
|
||||
"""Test pushing and pulling string tasks."""
|
||||
tasks = ["task1", "task2", "task3"]
|
||||
|
||||
# Push tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull tasks (FIFO order)
|
||||
pulled_tasks = test_queue.pull_tasks(3)
|
||||
|
||||
# Should get tasks in FIFO order (lpush + rpop = FIFO)
|
||||
assert pulled_tasks == ["task1", "task2", "task3"]
|
||||
|
||||
def test_push_and_pull_multiple_tasks(self, test_queue):
|
||||
"""Test pushing and pulling multiple tasks at once."""
|
||||
tasks = ["task1", "task2", "task3", "task4", "task5"]
|
||||
|
||||
# Push tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull multiple tasks
|
||||
pulled_tasks = test_queue.pull_tasks(3)
|
||||
assert len(pulled_tasks) == 3
|
||||
assert pulled_tasks == ["task1", "task2", "task3"]
|
||||
|
||||
# Pull remaining tasks
|
||||
remaining_tasks = test_queue.pull_tasks(5)
|
||||
assert len(remaining_tasks) == 2
|
||||
assert remaining_tasks == ["task4", "task5"]
|
||||
|
||||
def test_push_and_pull_complex_objects(self, test_queue, fake):
|
||||
"""Test pushing and pulling complex object tasks."""
|
||||
# Create complex task objects as dictionaries (not dataclass instances)
|
||||
tasks = [
|
||||
{
|
||||
"task_id": str(uuid4()),
|
||||
"tenant_id": test_queue._tenant_id,
|
||||
"data": {
|
||||
"file_id": str(uuid4()),
|
||||
"content": fake.text(),
|
||||
"metadata": {"size": fake.random_int(1000, 10000)},
|
||||
},
|
||||
"metadata": {"created_at": fake.iso8601(), "tags": fake.words(3)},
|
||||
},
|
||||
{
|
||||
"task_id": str(uuid4()),
|
||||
"tenant_id": test_queue._tenant_id,
|
||||
"data": {
|
||||
"file_id": str(uuid4()),
|
||||
"content": "测试中文内容",
|
||||
"metadata": {"size": fake.random_int(1000, 10000)},
|
||||
},
|
||||
"metadata": {"created_at": fake.iso8601(), "tags": ["中文", "测试", "emoji🚀"]},
|
||||
},
|
||||
]
|
||||
|
||||
# Push complex tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull tasks
|
||||
pulled_tasks = test_queue.pull_tasks(2)
|
||||
assert len(pulled_tasks) == 2
|
||||
|
||||
# Verify deserialized tasks match original (FIFO order)
|
||||
for i, pulled_task in enumerate(pulled_tasks):
|
||||
original_task = tasks[i] # FIFO order
|
||||
assert isinstance(pulled_task, dict)
|
||||
assert pulled_task["task_id"] == original_task["task_id"]
|
||||
assert pulled_task["tenant_id"] == original_task["tenant_id"]
|
||||
assert pulled_task["data"] == original_task["data"]
|
||||
assert pulled_task["metadata"] == original_task["metadata"]
|
||||
|
||||
def test_mixed_task_types(self, test_queue, fake):
|
||||
"""Test pushing and pulling mixed string and object tasks."""
|
||||
string_task = "simple_string_task"
|
||||
object_task = {
|
||||
"task_id": str(uuid4()),
|
||||
"dataset_id": str(uuid4()),
|
||||
"document_ids": [str(uuid4()) for _ in range(3)],
|
||||
}
|
||||
|
||||
tasks = [string_task, object_task, "another_string"]
|
||||
|
||||
# Push mixed tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull all tasks
|
||||
pulled_tasks = test_queue.pull_tasks(3)
|
||||
assert len(pulled_tasks) == 3
|
||||
|
||||
# Verify types and content
|
||||
assert pulled_tasks[0] == string_task
|
||||
assert isinstance(pulled_tasks[1], dict)
|
||||
assert pulled_tasks[1] == object_task
|
||||
assert pulled_tasks[2] == "another_string"
|
||||
|
||||
def test_empty_queue_operations(self, test_queue):
|
||||
"""Test operations on empty queue."""
|
||||
# Pull from empty queue
|
||||
tasks = test_queue.pull_tasks(5)
|
||||
assert tasks == []
|
||||
|
||||
# Pull zero or negative count
|
||||
assert test_queue.pull_tasks(0) == []
|
||||
assert test_queue.pull_tasks(-1) == []
|
||||
|
||||
def test_task_ttl_expiration(self, test_queue):
|
||||
"""Test task key TTL expiration."""
|
||||
# Set task with short TTL
|
||||
short_ttl = 2
|
||||
test_queue.set_task_waiting_time(short_ttl)
|
||||
|
||||
# Verify task key exists
|
||||
assert test_queue.get_task_key() == b"1" or test_queue.get_task_key() == "1"
|
||||
|
||||
# Wait for TTL to expire
|
||||
time.sleep(short_ttl + 1)
|
||||
|
||||
# Verify task key has expired
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
def test_large_task_batch(self, test_queue, fake):
|
||||
"""Test handling large batches of tasks."""
|
||||
# Create large batch of tasks
|
||||
large_batch = []
|
||||
for i in range(100):
|
||||
task = {
|
||||
"task_id": str(uuid4()),
|
||||
"index": i,
|
||||
"data": fake.text(max_nb_chars=100),
|
||||
"metadata": {"batch_id": str(uuid4())},
|
||||
}
|
||||
large_batch.append(task)
|
||||
|
||||
# Push large batch
|
||||
test_queue.push_tasks(large_batch)
|
||||
|
||||
# Pull all tasks
|
||||
pulled_tasks = test_queue.pull_tasks(100)
|
||||
assert len(pulled_tasks) == 100
|
||||
|
||||
# Verify all tasks were retrieved correctly (FIFO order)
|
||||
for i, task in enumerate(pulled_tasks):
|
||||
assert isinstance(task, dict)
|
||||
assert task["index"] == i # FIFO order
|
||||
|
||||
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
|
||||
"""Test concurrent operations on different queues."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
|
||||
# Create multiple queues for the same tenant
|
||||
queue1 = TenantIsolatedTaskQueue(tenant.id, "queue1")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant.id, "queue2")
|
||||
|
||||
# Push tasks to different queues
|
||||
queue1.push_tasks(["task1_queue1", "task2_queue1"])
|
||||
queue2.push_tasks(["task1_queue2", "task2_queue2"])
|
||||
|
||||
# Verify queues are isolated
|
||||
tasks1 = queue1.pull_tasks(2)
|
||||
tasks2 = queue2.pull_tasks(2)
|
||||
|
||||
assert tasks1 == ["task1_queue1", "task2_queue1"]
|
||||
assert tasks2 == ["task1_queue2", "task2_queue2"]
|
||||
assert tasks1 != tasks2
|
||||
|
||||
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
|
||||
"""Test TaskWrapper serialization and deserialization roundtrip."""
|
||||
# Create complex nested data
|
||||
complex_data = {
|
||||
"id": str(uuid4()),
|
||||
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5], "unicode": "测试中文", "emoji": "🚀"}},
|
||||
"metadata": {"created_at": fake.iso8601(), "tags": ["tag1", "tag2", "tag3"]},
|
||||
}
|
||||
|
||||
# Create wrapper and serialize
|
||||
wrapper = TaskWrapper(data=complex_data)
|
||||
serialized = wrapper.serialize()
|
||||
|
||||
# Verify serialization
|
||||
assert isinstance(serialized, str)
|
||||
assert "测试中文" in serialized
|
||||
assert "🚀" in serialized
|
||||
|
||||
# Deserialize and verify
|
||||
deserialized_wrapper = TaskWrapper.deserialize(serialized)
|
||||
assert deserialized_wrapper.data == complex_data
|
||||
|
||||
def test_error_handling_invalid_json(self, test_queue):
|
||||
"""Test error handling for invalid JSON in wrapped tasks."""
|
||||
# Manually create invalid JSON task (not a valid TaskWrapper JSON)
|
||||
invalid_json_task = "invalid json data"
|
||||
|
||||
# Push invalid task directly to Redis
|
||||
redis_client.lpush(test_queue._queue, invalid_json_task)
|
||||
|
||||
# Pull task - should fall back to string since it's not valid JSON
|
||||
task = test_queue.pull_tasks(1)
|
||||
assert task[0] == invalid_json_task
|
||||
|
||||
def test_real_world_batch_processing_scenario(self, test_queue, fake):
|
||||
"""Test realistic batch processing scenario."""
|
||||
# Simulate batch processing tasks
|
||||
batch_tasks = []
|
||||
for i in range(3):
|
||||
task = {
|
||||
"file_id": str(uuid4()),
|
||||
"tenant_id": test_queue._tenant_id,
|
||||
"user_id": str(uuid4()),
|
||||
"processing_config": {
|
||||
"model": fake.random_element(["model_a", "model_b", "model_c"]),
|
||||
"temperature": fake.random.uniform(0.1, 1.0),
|
||||
"max_tokens": fake.random_int(1000, 4000),
|
||||
},
|
||||
"metadata": {
|
||||
"source": fake.random_element(["upload", "api", "webhook"]),
|
||||
"priority": fake.random_element(["low", "normal", "high"]),
|
||||
},
|
||||
}
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Push tasks
|
||||
test_queue.push_tasks(batch_tasks)
|
||||
|
||||
# Process tasks in batches
|
||||
batch_size = 2
|
||||
processed_tasks = []
|
||||
|
||||
while True:
|
||||
batch = test_queue.pull_tasks(batch_size)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
processed_tasks.extend(batch)
|
||||
|
||||
# Verify all tasks were processed
|
||||
assert len(processed_tasks) == 3
|
||||
|
||||
# Verify task structure
|
||||
for task in processed_tasks:
|
||||
assert isinstance(task, dict)
|
||||
assert "file_id" in task
|
||||
assert "tenant_id" in task
|
||||
assert "processing_config" in task
|
||||
assert "metadata" in task
|
||||
assert task["tenant_id"] == test_queue._tenant_id
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueueCompatibility:
|
||||
"""Compatibility tests for migrating from legacy string-only queues."""
|
||||
|
||||
@pytest.fixture
|
||||
def fake(self):
|
||||
"""Faker instance for generating test data."""
|
||||
return Faker()
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_and_account(self, db_session_with_containers, fake):
|
||||
"""Create test tenant and account for testing."""
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tenant, account
|
||||
|
||||
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
|
||||
"""
|
||||
Test compatibility with legacy queues containing only string data.
|
||||
|
||||
This simulates the scenario where Redis queues already contain string data
|
||||
from the old architecture, and we need to ensure the new code can read them.
|
||||
"""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "legacy_queue")
|
||||
|
||||
# Simulate legacy string data in Redis queue (using old format)
|
||||
legacy_strings = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
|
||||
|
||||
# Manually push legacy strings directly to Redis (simulating old system)
|
||||
for legacy_string in legacy_strings:
|
||||
redis_client.lpush(queue._queue, legacy_string)
|
||||
|
||||
# Verify new code can read legacy string data
|
||||
pulled_tasks = queue.pull_tasks(5)
|
||||
assert len(pulled_tasks) == 5
|
||||
|
||||
# Verify all tasks are strings (not wrapped)
|
||||
for task in pulled_tasks:
|
||||
assert isinstance(task, str)
|
||||
assert task.startswith("legacy_task_")
|
||||
|
||||
# Verify order (FIFO from Redis list)
|
||||
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
|
||||
assert pulled_tasks == expected_order
|
||||
|
||||
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
|
||||
"""
|
||||
Test complete migration scenario from legacy to new system.
|
||||
|
||||
This simulates the real-world scenario where:
|
||||
1. Legacy system has string data in Redis
|
||||
2. New system starts processing the same queue
|
||||
3. Both legacy and new tasks coexist during migration
|
||||
4. New system can handle both formats seamlessly
|
||||
"""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "migration_queue")
|
||||
|
||||
# Phase 1: Legacy system has data
|
||||
legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)]
|
||||
redis_client.lpush(queue._queue, *legacy_tasks)
|
||||
|
||||
# Phase 2: New system starts processing legacy data
|
||||
processed_legacy = []
|
||||
while True:
|
||||
tasks = queue.pull_tasks(1)
|
||||
if not tasks:
|
||||
break
|
||||
processed_legacy.extend(tasks)
|
||||
|
||||
# Verify legacy data was processed correctly
|
||||
assert len(processed_legacy) == 5
|
||||
for task in processed_legacy:
|
||||
assert isinstance(task, str)
|
||||
assert task.startswith("legacy_resource_")
|
||||
|
||||
# Phase 3: New system adds new tasks (mixed types)
|
||||
new_string_tasks = ["new_resource_1", "new_resource_2"]
|
||||
new_object_tasks = [
|
||||
{
|
||||
"resource_id": str(uuid4()),
|
||||
"tenant_id": tenant.id,
|
||||
"processing_type": "new_system",
|
||||
"metadata": {"version": "2.0", "features": ["ai", "ml"]},
|
||||
},
|
||||
{
|
||||
"resource_id": str(uuid4()),
|
||||
"tenant_id": tenant.id,
|
||||
"processing_type": "new_system",
|
||||
"metadata": {"version": "2.0", "features": ["ai", "ml"]},
|
||||
},
|
||||
]
|
||||
|
||||
# Push new tasks using new system
|
||||
queue.push_tasks(new_string_tasks)
|
||||
queue.push_tasks(new_object_tasks)
|
||||
|
||||
# Phase 4: Process all new tasks
|
||||
processed_new = []
|
||||
while True:
|
||||
tasks = queue.pull_tasks(1)
|
||||
if not tasks:
|
||||
break
|
||||
processed_new.extend(tasks)
|
||||
|
||||
# Verify new tasks were processed correctly
|
||||
assert len(processed_new) == 4
|
||||
|
||||
string_tasks = [task for task in processed_new if isinstance(task, str)]
|
||||
object_tasks = [task for task in processed_new if isinstance(task, dict)]
|
||||
|
||||
assert len(string_tasks) == 2
|
||||
assert len(object_tasks) == 2
|
||||
|
||||
# Verify string tasks
|
||||
for task in string_tasks:
|
||||
assert task.startswith("new_resource_")
|
||||
|
||||
# Verify object tasks
|
||||
for task in object_tasks:
|
||||
assert isinstance(task, dict)
|
||||
assert "resource_id" in task
|
||||
assert "tenant_id" in task
|
||||
assert task["tenant_id"] == tenant.id
|
||||
assert task["processing_type"] == "new_system"
|
||||
|
||||
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
|
||||
"""
|
||||
Test error recovery when legacy queue contains malformed data.
|
||||
|
||||
This ensures the new system can gracefully handle corrupted or
|
||||
malformed legacy data without crashing.
|
||||
"""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "error_recovery_queue")
|
||||
|
||||
# Create mix of valid and malformed legacy data
|
||||
mixed_legacy_data = [
|
||||
"valid_legacy_task_1",
|
||||
"valid_legacy_task_2",
|
||||
"malformed_data_string", # This should be treated as string
|
||||
"valid_legacy_task_3",
|
||||
"invalid_json_not_taskwrapper_format", # This should fall back to string (not valid TaskWrapper JSON)
|
||||
"valid_legacy_task_4",
|
||||
]
|
||||
|
||||
# Manually push mixed data directly to Redis
|
||||
redis_client.lpush(queue._queue, *mixed_legacy_data)
|
||||
|
||||
# Process all tasks
|
||||
processed_tasks = []
|
||||
while True:
|
||||
tasks = queue.pull_tasks(1)
|
||||
if not tasks:
|
||||
break
|
||||
processed_tasks.extend(tasks)
|
||||
|
||||
# Verify all tasks were processed (no crashes)
|
||||
assert len(processed_tasks) == 6
|
||||
|
||||
# Verify all tasks are strings (malformed data falls back to string)
|
||||
for task in processed_tasks:
|
||||
assert isinstance(task, str)
|
||||
|
||||
# Verify valid tasks are preserved
|
||||
valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")]
|
||||
assert len(valid_tasks) == 4
|
||||
|
||||
# Verify malformed data is handled gracefully
|
||||
malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")]
|
||||
assert len(malformed_tasks) == 2
|
||||
assert "malformed_data_string" in malformed_tasks
|
||||
assert "invalid_json_not_taskwrapper_format" in malformed_tasks
|
||||
@ -1,17 +1,33 @@
|
||||
from dataclasses import asdict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_task import (
|
||||
_document_indexing, # Core function
|
||||
_document_indexing_with_tenant_queue, # Tenant queue wrapper function
|
||||
document_indexing_task, # Deprecated old interface
|
||||
normal_document_indexing_task, # New normal task
|
||||
priority_document_indexing_task, # New priority task
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentIndexingTask:
|
||||
"""Integration tests for document_indexing_task using testcontainers."""
|
||||
class TestDocumentIndexingTasks:
|
||||
"""Integration tests for document indexing tasks using testcontainers.
|
||||
|
||||
This test class covers:
|
||||
- Core _document_indexing function
|
||||
- Deprecated document_indexing_task function
|
||||
- New normal_document_indexing_task function
|
||||
- New priority_document_indexing_task function
|
||||
- Tenant queue wrapper _document_indexing_with_tenant_queue function
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
@ -224,7 +240,7 @@ class TestDocumentIndexingTask:
|
||||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify indexing runner was called correctly
|
||||
@ -232,10 +248,11 @@ class TestDocumentIndexingTask:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated to parsing status
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with correct documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
@ -261,7 +278,7 @@ class TestDocumentIndexingTask:
|
||||
document_ids = [fake.uuid4() for _ in range(3)]
|
||||
|
||||
# Act: Execute the task with non-existent dataset
|
||||
document_indexing_task(non_existent_dataset_id, document_ids)
|
||||
_document_indexing(non_existent_dataset_id, document_ids)
|
||||
|
||||
# Assert: Verify no processing occurred
|
||||
mock_external_service_dependencies["indexing_runner"].assert_not_called()
|
||||
@ -291,17 +308,18 @@ class TestDocumentIndexingTask:
|
||||
all_document_ids = existing_document_ids + non_existent_document_ids
|
||||
|
||||
# Act: Execute the task with mixed document IDs
|
||||
document_indexing_task(dataset.id, all_document_ids)
|
||||
_document_indexing(dataset.id, all_document_ids)
|
||||
|
||||
# Assert: Verify only existing documents were processed
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only existing documents were updated
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in existing_document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with only existing documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
@ -333,7 +351,7 @@ class TestDocumentIndexingTask:
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
@ -341,10 +359,11 @@ class TestDocumentIndexingTask:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing close the session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
def test_document_indexing_task_mixed_document_states(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
@ -407,17 +426,18 @@ class TestDocumentIndexingTask:
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with mixed document states
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify all documents were updated to parsing status
|
||||
for document in all_documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with all documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
@ -470,15 +490,16 @@ class TestDocumentIndexingTask:
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with too many documents for sandbox plan
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify error handling
|
||||
for document in all_documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error is not None
|
||||
assert "batch upload" in document.error
|
||||
assert document.stopped_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.error is not None
|
||||
assert "batch upload" in updated_document.error
|
||||
assert updated_document.stopped_at is not None
|
||||
|
||||
# Verify no indexing runner was called
|
||||
mock_external_service_dependencies["indexing_runner"].assert_not_called()
|
||||
@ -503,17 +524,18 @@ class TestDocumentIndexingTask:
|
||||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the task with billing disabled
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify successful processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated to parsing status
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
def test_document_indexing_task_document_is_paused_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
@ -541,7 +563,7 @@ class TestDocumentIndexingTask:
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
@ -549,7 +571,317 @@ class TestDocumentIndexingTask:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
|
||||
def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test document_indexing_task basic functionality.
|
||||
|
||||
This test verifies:
|
||||
- Task function calls the wrapper correctly
|
||||
- Basic parameter passing works
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the deprecated task (it only takes 2 parameters)
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_normal_document_indexing_task_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test normal_document_indexing_task basic functionality.
|
||||
|
||||
This test verifies:
|
||||
- Task function calls the wrapper correctly
|
||||
- Basic parameter passing works
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
# Act: Execute the new normal task
|
||||
normal_document_indexing_task(tenant_id, dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_priority_document_indexing_task_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test priority_document_indexing_task basic functionality.
|
||||
|
||||
This test verifies:
|
||||
- Task function calls the wrapper correctly
|
||||
- Basic parameter passing works
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
# Act: Execute the new priority task
|
||||
priority_document_indexing_task(tenant_id, dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_document_indexing_with_tenant_queue_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test _document_indexing_with_tenant_queue function with no waiting tasks.
|
||||
|
||||
This test verifies:
|
||||
- Core indexing logic execution (same as _document_indexing)
|
||||
- Tenant queue cleanup when no waiting tasks
|
||||
- Task function parameter passing
|
||||
- Queue management after processing
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Assert: Verify core processing occurred (same as _document_indexing)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated (same as _document_indexing)
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with correct documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
assert call_args is not None
|
||||
processed_documents = call_args[0][0]
|
||||
assert len(processed_documents) == 2
|
||||
|
||||
# Verify task function was not called (no waiting tasks)
|
||||
mock_task_func.delay.assert_not_called()
|
||||
|
||||
def test_document_indexing_with_tenant_queue_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Core indexing logic execution
|
||||
- Real Redis-based tenant queue processing of waiting tasks
|
||||
- Task function calls for waiting tasks
|
||||
- Queue management with multiple tasks using actual Redis operations
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
dataset_id = dataset.id
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
|
||||
# Create real queue instance
|
||||
queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_tasks = [
|
||||
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
|
||||
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"]),
|
||||
]
|
||||
# Convert DocumentTask objects to dictionaries for serialization
|
||||
waiting_task_dicts = [asdict(task) for task in waiting_tasks]
|
||||
queue.push_tasks(waiting_task_dicts)
|
||||
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify task function was called for each waiting task
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (tasks were pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||
assert len(remaining_tasks) == 1
|
||||
|
||||
def test_document_indexing_with_tenant_queue_error_handling(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling in _document_indexing_with_tenant_queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Exception handling during core processing
|
||||
- Tenant queue cleanup even on errors using real Redis
|
||||
- Proper error logging
|
||||
- Function completes without raising exceptions
|
||||
- Queue management continues despite core processing errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
dataset_id = dataset.id
|
||||
|
||||
# Mock IndexingRunner to raise an exception
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error")
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
|
||||
# Create real queue instance
|
||||
queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
|
||||
queue.push_tasks([asdict(waiting_task)])
|
||||
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_document_indexing_with_tenant_queue_tenant_isolation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Different tenants have isolated queues
|
||||
- Tasks from one tenant don't affect another tenant's queue
|
||||
- Queue operations are properly scoped to tenant
|
||||
"""
|
||||
# Arrange: Create test data for two different tenants
|
||||
dataset1, documents1 = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
dataset2, documents2 = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
|
||||
tenant1_id = dataset1.tenant_id
|
||||
tenant2_id = dataset2.tenant_id
|
||||
dataset1_id = dataset1.id
|
||||
dataset2_id = dataset2.id
|
||||
document_ids1 = [doc.id for doc in documents1]
|
||||
document_ids2 = [doc.id for doc in documents2]
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
|
||||
# Create queue instances for both tenants
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1_id, "document_indexing")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2_id, "document_indexing")
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
|
||||
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
|
||||
|
||||
queue1.push_tasks([asdict(waiting_task1)])
|
||||
queue2.push_tasks([asdict(waiting_task2)])
|
||||
|
||||
# Act: Execute the wrapper function for tenant1 only
|
||||
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_task_func.delay.assert_called_once()
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
assert len(remaining_tasks1) == 0
|
||||
|
||||
# Verify tenant2's queue still has its task (isolation)
|
||||
remaining_tasks2 = queue2.pull_tasks(count=10)
|
||||
assert len(remaining_tasks2) == 1
|
||||
|
||||
# Verify queue keys are different
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
@ -0,0 +1,936 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
|
||||
priority_rag_pipeline_run_task,
|
||||
run_single_rag_pipeline_task,
|
||||
)
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
|
||||
class TestRagPipelineRunTasks:
|
||||
"""Integration tests for RAG pipeline run tasks using testcontainers.
|
||||
|
||||
This test class covers:
|
||||
- priority_rag_pipeline_run_task function
|
||||
- rag_pipeline_run_task function
|
||||
- run_single_rag_pipeline_task function
|
||||
- Real Redis-based TenantIsolatedTaskQueue operations
|
||||
- PipelineGenerator._generate method mocking and parameter validation
|
||||
- File operations and cleanup
|
||||
- Error handling and queue management
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline_generator(self):
|
||||
"""Mock PipelineGenerator._generate method."""
|
||||
with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate:
|
||||
# Mock the _generate method to return a simple response
|
||||
mock_generate.return_value = {"answer": "Test response", "metadata": {"test": "data"}}
|
||||
yield mock_generate
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_service(self):
|
||||
"""Mock FileService for file operations."""
|
||||
with (
|
||||
patch("services.file_service.FileService.get_file_content") as mock_get_content,
|
||||
patch("services.file_service.FileService.delete_file") as mock_delete_file,
|
||||
):
|
||||
yield {
|
||||
"get_content": mock_get_content,
|
||||
"delete_file": mock_delete_file,
|
||||
}
|
||||
|
||||
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
|
||||
"""
|
||||
Helper method to create test pipeline and workflow for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant, pipeline, workflow) - Created entities
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
marked_name=fake.company(),
|
||||
marked_comment=fake.text(max_nb_chars=100),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
workflow_id=workflow.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(pipeline)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh entities to ensure they're properly loaded
|
||||
db.session.refresh(account)
|
||||
db.session.refresh(tenant)
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(pipeline)
|
||||
|
||||
return account, tenant, pipeline, workflow
|
||||
|
||||
def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2):
|
||||
"""
|
||||
Helper method to create RAG pipeline invoke entities for testing.
|
||||
|
||||
Args:
|
||||
account: Account instance
|
||||
tenant: Tenant instance
|
||||
pipeline: Pipeline instance
|
||||
workflow: Workflow instance
|
||||
count: Number of entities to create
|
||||
|
||||
Returns:
|
||||
list: List of RagPipelineInvokeEntity instances
|
||||
"""
|
||||
fake = Faker()
|
||||
entities = []
|
||||
|
||||
for i in range(count):
|
||||
# Create application generate entity
|
||||
app_config = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": fake.company(),
|
||||
"mode": "workflow",
|
||||
"workflow_id": workflow.id,
|
||||
"tenant_id": tenant.id,
|
||||
"app_mode": "workflow",
|
||||
}
|
||||
|
||||
application_generate_entity = {
|
||||
"task_id": str(uuid.uuid4()),
|
||||
"app_config": app_config,
|
||||
"inputs": {"query": f"Test query {i}"},
|
||||
"files": [],
|
||||
"user_id": account.id,
|
||||
"stream": False,
|
||||
"invoke_from": "published",
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"pipeline_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": fake.company(),
|
||||
"mode": "workflow",
|
||||
"workflow_id": workflow.id,
|
||||
"tenant_id": tenant.id,
|
||||
"app_mode": "workflow",
|
||||
},
|
||||
"datasource_type": "upload_file",
|
||||
"datasource_info": {},
|
||||
"dataset_id": str(uuid.uuid4()),
|
||||
"batch": "test_batch",
|
||||
}
|
||||
|
||||
entity = RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline.id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
workflow_id=workflow.id,
|
||||
streaming=False,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
workflow_thread_pool_id=str(uuid.uuid4()),
|
||||
)
|
||||
entities.append(entity)
|
||||
|
||||
return entities
|
||||
|
||||
def _create_file_content_for_entities(self, entities):
|
||||
"""
|
||||
Helper method to create file content for RAG pipeline invoke entities.
|
||||
|
||||
Args:
|
||||
entities: List of RagPipelineInvokeEntity instances
|
||||
|
||||
Returns:
|
||||
str: JSON string containing serialized entities
|
||||
"""
|
||||
entities_data = [entity.model_dump() for entity in entities]
|
||||
return json.dumps(entities_data)
|
||||
|
||||
def test_priority_rag_pipeline_run_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test successful priority RAG pipeline run task execution.
|
||||
|
||||
This test verifies:
|
||||
- Task execution with multiple RAG pipeline invoke entities
|
||||
- File content retrieval and parsing
|
||||
- PipelineGenerator._generate method calls with correct parameters
|
||||
- Thread pool execution
|
||||
- File cleanup after execution
|
||||
- Queue management with no waiting tasks
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Act: Execute the priority task
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify expected outcomes
|
||||
# Verify file operations
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
|
||||
# Verify PipelineGenerator._generate was called for each entity
|
||||
assert mock_pipeline_generator.call_count == 2
|
||||
|
||||
# Verify call parameters for each entity
|
||||
calls = mock_pipeline_generator.call_args_list
|
||||
for call in calls:
|
||||
call_kwargs = call[1] # Get keyword arguments
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_rag_pipeline_run_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test successful regular RAG pipeline run task execution.
|
||||
|
||||
This test verifies:
|
||||
- Task execution with multiple RAG pipeline invoke entities
|
||||
- File content retrieval and parsing
|
||||
- PipelineGenerator._generate method calls with correct parameters
|
||||
- Thread pool execution
|
||||
- File cleanup after execution
|
||||
- Queue management with no waiting tasks
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify expected outcomes
|
||||
# Verify file operations
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
|
||||
# Verify PipelineGenerator._generate was called for each entity
|
||||
assert mock_pipeline_generator.call_count == 3
|
||||
|
||||
# Verify call parameters for each entity
|
||||
calls = mock_pipeline_generator.call_args_list
|
||||
for call in calls:
|
||||
call_kwargs = call[1] # Get keyword arguments
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Core task execution
|
||||
- Real Redis-based tenant queue processing of waiting tasks
|
||||
- Task function calls for waiting tasks
|
||||
- Queue management with multiple tasks using actual Redis operations
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act: Execute the priority task
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
|
||||
|
||||
def test_rag_pipeline_run_task_legacy_compatibility(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
|
||||
|
||||
This test simulates the scenario where:
|
||||
- Old code writes file IDs directly to Redis list using lpush
|
||||
- New worker processes these legacy queue entries
|
||||
- Ensures backward compatibility during deployment transition
|
||||
|
||||
Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
|
||||
New format: TenantIsolatedTaskQueue.push_tasks([file_id])
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Simulate legacy Redis queue format - direct file IDs in Redis list
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Legacy queue key format (old code)
|
||||
legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
|
||||
legacy_task_key = f"tenant_pipeline_task:{tenant.id}"
|
||||
|
||||
# Add legacy format data to Redis (simulating old code behavior)
|
||||
legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
for file_id_legacy in legacy_file_ids:
|
||||
redis_client.lpush(legacy_queue_key, file_id_legacy)
|
||||
|
||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the priority task with new code but legacy queue data
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify that new code can process legacy queue entries
|
||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
|
||||
|
||||
# Cleanup: Remove legacy test data
|
||||
redis_client.delete(legacy_queue_key)
|
||||
redis_client.delete(legacy_task_key)
|
||||
|
||||
def test_rag_pipeline_run_task_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Core task execution
|
||||
- Real Redis-based tenant queue processing of waiting tasks
|
||||
- Task function calls for waiting tasks
|
||||
- Queue management with multiple tasks using actual Redis operations
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
|
||||
|
||||
def test_priority_rag_pipeline_run_task_error_handling(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test error handling in priority RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Exception handling during core processing
|
||||
- Tenant queue cleanup even on errors using real Redis
|
||||
- Proper error logging
|
||||
- Function completes without raising exceptions
|
||||
- Queue management continues despite core processing errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Mock PipelineGenerator to raise an exception
|
||||
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act: Execute the priority task (should not raise exception)
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_rag_pipeline_run_task_error_handling(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test error handling in regular RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Exception handling during core processing
|
||||
- Tenant queue cleanup even on errors using real Redis
|
||||
- Proper error logging
|
||||
- Function completes without raising exceptions
|
||||
- Queue management continues despite core processing errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Mock PipelineGenerator to raise an exception
|
||||
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task (should not raise exception)
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_priority_rag_pipeline_run_task_tenant_isolation(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in priority RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Different tenants have isolated queues
|
||||
- Tasks from one tenant don't affect another tenant's queue
|
||||
- Queue operations are properly scoped to tenant
|
||||
"""
|
||||
# Arrange: Create test data for two different tenants
|
||||
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
|
||||
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
|
||||
|
||||
file_content1 = self._create_file_content_for_entities(entities1)
|
||||
file_content2 = self._create_file_content_for_entities(entities2)
|
||||
|
||||
# Mock file service
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_file_id1 = str(uuid.uuid4())
|
||||
waiting_file_id2 = str(uuid.uuid4())
|
||||
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act: Execute the priority task for tenant1 only
|
||||
priority_rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
assert mock_file_service["get_content"].call_count == 1
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
assert len(remaining_tasks1) == 0
|
||||
|
||||
# Verify tenant2's queue still has its task (isolation)
|
||||
remaining_tasks2 = queue2.pull_tasks(count=10)
|
||||
assert len(remaining_tasks2) == 1
|
||||
|
||||
# Verify queue keys are different
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
def test_rag_pipeline_run_task_tenant_isolation(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in regular RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Different tenants have isolated queues
|
||||
- Tasks from one tenant don't affect another tenant's queue
|
||||
- Queue operations are properly scoped to tenant
|
||||
"""
|
||||
# Arrange: Create test data for two different tenants
|
||||
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
|
||||
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
|
||||
|
||||
file_content1 = self._create_file_content_for_entities(entities1)
|
||||
file_content2 = self._create_file_content_for_entities(entities2)
|
||||
|
||||
# Mock file service
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_file_id1 = str(uuid.uuid4())
|
||||
waiting_file_id2 = str(uuid.uuid4())
|
||||
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task for tenant1 only
|
||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
assert mock_file_service["get_content"].call_count == 1
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
assert len(remaining_tasks1) == 0
|
||||
|
||||
# Verify tenant2's queue still has its task (isolation)
|
||||
remaining_tasks2 = queue2.pull_tasks(count=10)
|
||||
assert len(remaining_tasks2) == 1
|
||||
|
||||
# Verify queue keys are different
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
def test_run_single_rag_pipeline_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test successful run_single_rag_pipeline_task execution.
|
||||
|
||||
This test verifies:
|
||||
- Single RAG pipeline task execution within Flask app context
|
||||
- Entity validation and database queries
|
||||
- PipelineGenerator._generate method call with correct parameters
|
||||
- Proper Flask context handling
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
entity_data = entities[0].model_dump()
|
||||
|
||||
# Act: Execute the single task
|
||||
with flask_app_with_containers.app_context():
|
||||
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
|
||||
|
||||
# Assert: Verify expected outcomes
|
||||
# Verify PipelineGenerator._generate was called
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify call parameters
|
||||
call = mock_pipeline_generator.call_args
|
||||
call_kwargs = call[1] # Get keyword arguments
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_run_single_rag_pipeline_task_entity_validation_error(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test run_single_rag_pipeline_task with invalid entity data.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid entity data
|
||||
- Exception logging
|
||||
- Function raises ValueError for missing entities
|
||||
"""
|
||||
# Arrange: Create entity data with valid UUIDs but non-existent entities
|
||||
fake = Faker()
|
||||
invalid_entity_data = {
|
||||
"pipeline_id": str(uuid.uuid4()),
|
||||
"application_generate_entity": {
|
||||
"app_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": "Test App",
|
||||
"mode": "workflow",
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
},
|
||||
"inputs": {"query": "Test query"},
|
||||
"query": "Test query",
|
||||
"response_mode": "blocking",
|
||||
"user": str(uuid.uuid4()),
|
||||
"files": [],
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
},
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
"streaming": False,
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"workflow_thread_pool_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Account .* not found"):
|
||||
run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers)
|
||||
|
||||
# Assert: Pipeline generator should not be called
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
def test_run_single_rag_pipeline_task_database_entity_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test run_single_rag_pipeline_task with non-existent database entities.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing database entities
|
||||
- Exception logging
|
||||
- Function raises ValueError for missing entities
|
||||
"""
|
||||
# Arrange: Create test data with non-existent IDs
|
||||
fake = Faker()
|
||||
entity_data = {
|
||||
"pipeline_id": str(uuid.uuid4()),
|
||||
"application_generate_entity": {
|
||||
"app_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": "Test App",
|
||||
"mode": "workflow",
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
},
|
||||
"inputs": {"query": "Test query"},
|
||||
"query": "Test query",
|
||||
"response_mode": "blocking",
|
||||
"user": str(uuid.uuid4()),
|
||||
"files": [],
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
},
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
"streaming": False,
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"workflow_thread_pool_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Account .* not found"):
|
||||
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
|
||||
|
||||
# Assert: Pipeline generator should not be called
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
def test_priority_rag_pipeline_run_task_file_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test priority RAG pipeline run task with non-existent file.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing files
|
||||
- Exception logging
|
||||
- Function raises Exception for file errors
|
||||
- Queue management continues despite file errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
# Mock file service to raise exception
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = Exception("File not found")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act & Assert: Execute the priority task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_rag_pipeline_run_task_file_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with non-existent file.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing files
|
||||
- Exception logging
|
||||
- Function raises Exception for file errors
|
||||
- Queue management continues despite file errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
# Mock file service to raise exception
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = Exception("File not found")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act & Assert: Execute the regular task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
301
api/tests/unit_tests/core/rag/pipeline/test_queue.py
Normal file
301
api/tests/unit_tests/core/rag/pipeline/test_queue.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""
|
||||
Unit tests for TenantIsolatedTaskQueue.
|
||||
|
||||
These tests verify the Redis-based task queue functionality for tenant-specific
|
||||
task management with proper serialization and deserialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
|
||||
|
||||
class TestTaskWrapper:
|
||||
"""Test cases for TaskWrapper serialization/deserialization."""
|
||||
|
||||
def test_serialize_simple_data(self):
|
||||
"""Test serialization of simple data types."""
|
||||
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
# Verify it's valid JSON
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["data"] == data
|
||||
|
||||
def test_serialize_complex_data(self):
|
||||
"""Test serialization of complex nested data."""
|
||||
data = {
|
||||
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
|
||||
"unicode": "测试中文",
|
||||
"special_chars": "!@#$%^&*()",
|
||||
}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["data"] == data
|
||||
|
||||
def test_deserialize_valid_data(self):
|
||||
"""Test deserialization of valid JSON data."""
|
||||
original_data = {"key": "value", "number": 42}
|
||||
# Serialize using TaskWrapper to get the correct format
|
||||
wrapper = TaskWrapper(data=original_data)
|
||||
serialized = wrapper.serialize()
|
||||
|
||||
wrapper = TaskWrapper.deserialize(serialized)
|
||||
assert wrapper.data == original_data
|
||||
|
||||
def test_deserialize_invalid_json(self):
|
||||
"""Test deserialization handles invalid JSON gracefully."""
|
||||
invalid_json = "{invalid json}"
|
||||
|
||||
# Pydantic will raise ValidationError for invalid JSON
|
||||
with pytest.raises(ValidationError):
|
||||
TaskWrapper.deserialize(invalid_json)
|
||||
|
||||
def test_serialize_ensure_ascii_false(self):
|
||||
"""Test that serialization preserves Unicode characters."""
|
||||
data = {"chinese": "中文测试", "emoji": "🚀"}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
assert "中文测试" in serialized
|
||||
assert "🚀" in serialized
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueue:
|
||||
"""Test cases for TenantIsolatedTaskQueue functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self):
|
||||
"""Mock Redis client for testing."""
|
||||
mock_redis = MagicMock()
|
||||
return mock_redis
|
||||
|
||||
@pytest.fixture
|
||||
def sample_queue(self, mock_redis_client):
|
||||
"""Create a sample TenantIsolatedTaskQueue instance."""
|
||||
return TenantIsolatedTaskQueue("tenant-123", "test-key")
|
||||
|
||||
def test_initialization(self, sample_queue):
|
||||
"""Test queue initialization with correct key generation."""
|
||||
assert sample_queue._tenant_id == "tenant-123"
|
||||
assert sample_queue._unique_key == "test-key"
|
||||
assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
|
||||
assert sample_queue._task_key == "tenant_test-key_task:tenant-123"
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_get_task_key_exists(self, mock_redis, sample_queue):
|
||||
"""Test getting task key when it exists."""
|
||||
mock_redis.get.return_value = "1"
|
||||
|
||||
result = sample_queue.get_task_key()
|
||||
|
||||
assert result == "1"
|
||||
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
|
||||
"""Test getting task key when it doesn't exist."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = sample_queue.get_task_key()
|
||||
|
||||
assert result is None
|
||||
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
|
||||
"""Test setting task waiting flag with default TTL."""
|
||||
sample_queue.set_task_waiting_time()
|
||||
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
"tenant_test-key_task:tenant-123",
|
||||
3600, # DEFAULT_TASK_TTL
|
||||
1,
|
||||
)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
|
||||
"""Test setting task waiting flag with custom TTL."""
|
||||
custom_ttl = 1800
|
||||
sample_queue.set_task_waiting_time(custom_ttl)
|
||||
|
||||
mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_delete_task_key(self, mock_redis, sample_queue):
|
||||
"""Test deleting task key."""
|
||||
sample_queue.delete_task_key()
|
||||
|
||||
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_string_list(self, mock_redis, sample_queue):
|
||||
"""Test pushing string tasks directly."""
|
||||
tasks = ["task1", "task2", "task3"]
|
||||
|
||||
sample_queue.push_tasks(tasks)
|
||||
|
||||
mock_redis.lpush.assert_called_once_with(
|
||||
"tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
|
||||
)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
|
||||
"""Test pushing mixed string and object tasks."""
|
||||
tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
|
||||
|
||||
sample_queue.push_tasks(tasks)
|
||||
|
||||
# Verify lpush was called
|
||||
mock_redis.lpush.assert_called_once()
|
||||
call_args = mock_redis.lpush.call_args
|
||||
|
||||
# Check queue name
|
||||
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
|
||||
|
||||
# Check serialized tasks
|
||||
serialized_tasks = call_args[0][1:]
|
||||
assert len(serialized_tasks) == 3
|
||||
assert serialized_tasks[0] == "string_task"
|
||||
assert serialized_tasks[2] == "another_string"
|
||||
|
||||
# Check object task is serialized as TaskWrapper JSON (without prefix)
|
||||
# It should be a valid JSON string that can be deserialized by TaskWrapper
|
||||
wrapper = TaskWrapper.deserialize(serialized_tasks[1])
|
||||
assert wrapper.data == {"object_task": "data", "id": 123}
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
|
||||
"""Test pushing empty task list."""
|
||||
sample_queue.push_tasks([])
|
||||
|
||||
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with default count (1)."""
|
||||
mock_redis.rpop.side_effect = ["task1", None]
|
||||
|
||||
result = sample_queue.pull_tasks()
|
||||
|
||||
assert result == ["task1"]
|
||||
assert mock_redis.rpop.call_count == 1
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with custom count."""
|
||||
# First test: pull 3 tasks
|
||||
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
|
||||
|
||||
result = sample_queue.pull_tasks(3)
|
||||
|
||||
assert result == ["task1", "task2", "task3"]
|
||||
assert mock_redis.rpop.call_count == 3
|
||||
|
||||
# Reset mock for second test
|
||||
mock_redis.reset_mock()
|
||||
mock_redis.rpop.side_effect = ["task1", "task2", None]
|
||||
|
||||
result = sample_queue.pull_tasks(3)
|
||||
|
||||
assert result == ["task1", "task2"]
|
||||
assert mock_redis.rpop.call_count == 3
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with zero count returns empty list."""
|
||||
result = sample_queue.pull_tasks(0)
|
||||
|
||||
assert result == []
|
||||
mock_redis.rpop.assert_not_called()
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with negative count returns empty list."""
|
||||
result = sample_queue.pull_tasks(-1)
|
||||
|
||||
assert result == []
|
||||
mock_redis.rpop.assert_not_called()
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks that include wrapped objects."""
|
||||
# Create a wrapped task
|
||||
task_data = {"task_id": 123, "data": "test"}
|
||||
wrapper = TaskWrapper(data=task_data)
|
||||
wrapped_task = wrapper.serialize()
|
||||
|
||||
mock_redis.rpop.side_effect = [
|
||||
"string_task",
|
||||
wrapped_task.encode("utf-8"), # Simulate bytes from Redis
|
||||
None,
|
||||
]
|
||||
|
||||
result = sample_queue.pull_tasks(2)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == "string_task"
|
||||
assert result[1] == {"task_id": 123, "data": "test"}
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with invalid JSON falls back to string."""
|
||||
# Invalid JSON string that cannot be deserialized
|
||||
invalid_json = "invalid json data"
|
||||
mock_redis.rpop.side_effect = [invalid_json, None]
|
||||
|
||||
result = sample_queue.pull_tasks(1)
|
||||
|
||||
assert result == [invalid_json]
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks handles bytes from Redis correctly."""
|
||||
mock_redis.rpop.side_effect = [
|
||||
b"task1", # bytes
|
||||
"task2", # string
|
||||
None,
|
||||
]
|
||||
|
||||
result = sample_queue.pull_tasks(2)
|
||||
|
||||
assert result == ["task1", "task2"]
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
|
||||
"""Test complex object serialization and deserialization roundtrip."""
|
||||
complex_task = {
|
||||
"id": uuid4().hex,
|
||||
"data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
|
||||
"metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
|
||||
}
|
||||
|
||||
# Push the complex task
|
||||
sample_queue.push_tasks([complex_task])
|
||||
|
||||
# Verify it was serialized as TaskWrapper JSON
|
||||
call_args = mock_redis.lpush.call_args
|
||||
wrapped_task = call_args[0][1]
|
||||
# Verify it's a valid TaskWrapper JSON (starts with {"data":)
|
||||
assert wrapped_task.startswith('{"data":')
|
||||
|
||||
# Verify it can be deserialized
|
||||
wrapper = TaskWrapper.deserialize(wrapped_task)
|
||||
assert wrapper.data == complex_task
|
||||
|
||||
# Simulate pulling it back
|
||||
mock_redis.rpop.return_value = wrapped_task
|
||||
result = sample_queue.pull_tasks(1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == complex_task
|
||||
@ -111,3 +111,26 @@ class TestVariablePoolGetAndNestedAttribute:
|
||||
assert segment_false is not None
|
||||
assert isinstance(segment_false, BooleanSegment)
|
||||
assert segment_false.value is False
|
||||
|
||||
|
||||
class TestVariablePoolGetNotModifyVariableDictionary:
|
||||
_NODE_ID = "start"
|
||||
_VAR_NAME = "name"
|
||||
|
||||
def test_convert_to_template_should_not_introduce_extra_keys(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add([self._NODE_ID, self._VAR_NAME], 0)
|
||||
pool.convert_template("The start.name is {{#start.name#}}")
|
||||
assert "The start" not in pool.variable_dictionary
|
||||
|
||||
def test_get_should_not_modify_variable_dictionary(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.get([self._NODE_ID, self._VAR_NAME])
|
||||
assert len(pool.variable_dictionary) == 1 # only contains `sys` node id
|
||||
assert "start" not in pool.variable_dictionary
|
||||
|
||||
pool = VariablePool.empty()
|
||||
pool.add([self._NODE_ID, self._VAR_NAME], "Joe")
|
||||
pool.get([self._NODE_ID, "count"])
|
||||
start_subdict = pool.variable_dictionary[self._NODE_ID]
|
||||
assert "count" not in start_subdict
|
||||
|
||||
@ -0,0 +1,317 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
|
||||
|
||||
class DocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_document_task_proxy(
|
||||
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
|
||||
) -> DocumentIndexingTaskProxy:
|
||||
"""Create DocumentIndexingTaskProxy instance for testing."""
|
||||
if document_ids is None:
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
|
||||
class TestDocumentIndexingTaskProxy:
|
||||
"""Test cases for DocumentIndexingTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test DocumentIndexingTaskProxy initialization."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
|
||||
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
|
||||
assert len(pushed_tasks) == 1
|
||||
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
|
||||
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
|
||||
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
|
||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
# If billing enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
def test_document_task_dataclass(self):
|
||||
"""Test DocumentTask dataclass."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2"]
|
||||
|
||||
# Act
|
||||
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
||||
|
||||
# Assert
|
||||
assert task.tenant_id == tenant_id
|
||||
assert task.dataset_id == dataset_id
|
||||
assert task.document_ids == document_ids
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
def test_initialization_with_empty_document_ids(self):
|
||||
"""Test initialization with empty document_ids list."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = []
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
||||
def test_initialization_with_single_document_id(self):
|
||||
"""Test initialization with single document_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
483
api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py
Normal file
483
api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py
Normal file
@ -0,0 +1,483 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
|
||||
|
||||
class RagPipelineTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_invoke_entity(
|
||||
pipeline_id: str = "pipeline-123",
|
||||
user_id: str = "user-456",
|
||||
tenant_id: str = "tenant-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
streaming: bool = True,
|
||||
workflow_execution_id: str | None = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> RagPipelineInvokeEntity:
|
||||
"""Create RagPipelineInvokeEntity instance for testing."""
|
||||
return RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline_id,
|
||||
application_generate_entity={"key": "value"},
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_id=workflow_id,
|
||||
streaming=streaming,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_task_proxy(
|
||||
dataset_tenant_id: str = "tenant-123",
|
||||
user_id: str = "user-456",
|
||||
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
|
||||
) -> RagPipelineTaskProxy:
|
||||
"""Create RagPipelineTaskProxy instance for testing."""
|
||||
if rag_pipeline_invoke_entities is None:
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
@staticmethod
|
||||
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
|
||||
"""Create mock upload file."""
|
||||
upload_file = Mock()
|
||||
upload_file.id = file_id
|
||||
return upload_file
|
||||
|
||||
|
||||
class TestRagPipelineTaskProxy:
|
||||
"""Test cases for RagPipelineTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RagPipelineTaskProxy initialization."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
|
||||
|
||||
def test_initialization_with_empty_entities(self):
|
||||
"""Test initialization with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = []
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == []
|
||||
|
||||
def test_initialization_with_multiple_entities(self):
|
||||
"""Test initialization with multiple rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
|
||||
]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert len(proxy._rag_pipeline_invoke_entities) == 3
|
||||
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
|
||||
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
|
||||
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-123"
|
||||
mock_file_service_class.assert_called_once_with(mock_db.engine)
|
||||
|
||||
# Verify upload_text was called with correct parameters
|
||||
mock_file_service.upload_text.assert_called_once()
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text, name, user_id, tenant_id = call_args[0]
|
||||
|
||||
assert name == "rag_pipeline_invoke_entities.json"
|
||||
assert user_id == "user-456"
|
||||
assert tenant_id == "tenant-123"
|
||||
|
||||
# Verify JSON content
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 1
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method with multiple entities."""
|
||||
# Arrange
|
||||
entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
]
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-456"
|
||||
|
||||
# Verify JSON content contains both entities
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text = call_args[0][0]
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 2
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
|
||||
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(upload_file_id, mock_task)
|
||||
|
||||
# If sent to direct queue, tenant_isolated_task_queue should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
# Celery should be called directly
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If task key exists, should push tasks to the queue
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
|
||||
# Celery should not be called directly
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If no task key, should set task waiting time key first
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
# The first task should be sent to celery directly, so push tasks should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(
|
||||
self, mock_db, mock_file_service_class, mock_feature_service
|
||||
):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
|
||||
"""Test _dispatch method when upload_file_id is empty."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = Mock()
|
||||
mock_upload_file.id = "" # Empty file ID
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="upload_file_id is empty"):
|
||||
proxy._dispatch()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._dispatch = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
proxy._dispatch.assert_called_once()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
|
||||
def test_delay_method_with_empty_entities(self, mock_logger):
|
||||
"""Test delay method with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
|
||||
)
|
||||
4597
api/uv.lock
generated
4597
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
|
||||
|
||||
uv --directory api run \
|
||||
celery -A app.celery worker \
|
||||
-P gevent -c 1 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
|
||||
-P gevent -c 1 --loglevel INFO -Q dataset,priority_dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
|
||||
|
||||
@ -1370,3 +1370,6 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
@ -614,6 +614,7 @@ x-shared-env: &shared-api-worker-env
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false}
|
||||
ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false}
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1}
|
||||
|
||||
services:
|
||||
# API service
|
||||
|
||||
@ -74,7 +74,8 @@ Chat applications support session persistence, allowing previous chat history to
|
||||
If set to `false`, can achieve async title generation by calling the conversation rename API and setting `auto_generate` to `true`.
|
||||
</Property>
|
||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.
|
||||
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.<br/>
|
||||
How to obtain: In the version history interface, click the copy icon on the right side of each version entry to copy the complete workflow ID.
|
||||
</Property>
|
||||
<Property name='trace_id' type='string' key='trace_id'>
|
||||
(Optional) Trace ID. Used for integration with existing business trace components to achieve end-to-end distributed tracing. If not provided, the system will automatically generate a trace_id. Supports the following three ways to pass, in order of priority:<br/>
|
||||
|
||||
@ -74,7 +74,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
`false`に設定すると、会話のリネームAPIを呼び出し、`auto_generate`を`true`に設定することで非同期タイトル生成を実現できます。
|
||||
</Property>
|
||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||
(オプション)ワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。
|
||||
(オプション)ワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。<br/>
|
||||
取得方法:バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
|
||||
</Property>
|
||||
<Property name='trace_id' type='string' key='trace_id'>
|
||||
(オプション)トレースID。既存の業務システムのトレースコンポーネントと連携し、エンドツーエンドの分散トレーシングを実現するために使用します。指定がない場合、システムが自動的に trace_id を生成します。以下の3つの方法で渡すことができ、優先順位は次のとおりです:<br/>
|
||||
|
||||
@ -72,7 +72,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
(选填)自动生成标题,默认 `true`。 若设置为 `false`,则可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
|
||||
</Property>
|
||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||
(选填)工作流ID,用于指定特定版本,如果不提供则使用默认的已发布版本。
|
||||
(选填)工作流ID,用于指定特定版本,如果不提供则使用默认的已发布版本。<br/>
|
||||
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
|
||||
</Property>
|
||||
<Property name='trace_id' type='string' key='trace_id'>
|
||||
(选填)链路追踪ID。适用于与业务系统已有的trace组件打通,实现端到端分布式追踪等场景。如果未指定,系统会自动生成<code>trace_id</code>。支持以下三种方式传递,具体优先级依次为:<br/>
|
||||
|
||||
@ -344,7 +344,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
### パス
|
||||
- `workflow_id` (string) 必須 特定バージョンのワークフローを指定するためのワークフローID
|
||||
|
||||
取得方法:バージョン履歴で特定バージョンのワークフローIDを照会できます。
|
||||
取得方法:バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
|
||||
|
||||
### リクエストボディ
|
||||
- `inputs` (object) 必須
|
||||
|
||||
@ -334,7 +334,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等
|
||||
### Path
|
||||
- `workflow_id` (string) Required 工作流ID,用于指定特定版本的工作流
|
||||
|
||||
获取方式:可以在版本历史中查询特定版本的工作流ID。
|
||||
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
|
||||
|
||||
### Request Body
|
||||
- `inputs` (object) Required
|
||||
|
||||
@ -86,7 +86,7 @@ const ModelList: FC<ModelListProps> = ({
|
||||
{
|
||||
models.map(model => (
|
||||
<ModelListItem
|
||||
key={`${model.model}-${model.fetch_from}`}
|
||||
key={`${model.model}-${model.model_type}-${model.fetch_from}`}
|
||||
{...{
|
||||
model,
|
||||
provider,
|
||||
|
||||
@ -856,6 +856,18 @@
|
||||
color: var(--color-prettylights-syntax-comment);
|
||||
}
|
||||
|
||||
.markdown-body .katex {
|
||||
/* Allow long inline formulas to wrap instead of overflowing */
|
||||
white-space: normal !important;
|
||||
overflow-wrap: break-word; /* better cross-browser support */
|
||||
word-break: break-word; /* non-standard fallback for older WebKit/Blink */
|
||||
}
|
||||
|
||||
.markdown-body .katex-display {
|
||||
/* Fallback for very long display equations */
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-body .pl-c1,
|
||||
.markdown-body .pl-s .pl-v {
|
||||
color: var(--color-prettylights-syntax-constant);
|
||||
|
||||
@ -10,6 +10,7 @@ import MaintenanceNotice from '@/app/components/header/maintenance-notice'
|
||||
import { noop } from 'lodash-es'
|
||||
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
|
||||
import { ZENDESK_FIELD_IDS } from '@/config'
|
||||
import { useGlobalPublicStore } from './global-public-context'
|
||||
|
||||
export type AppContextValue = {
|
||||
userProfile: UserProfileResponse
|
||||
@ -77,6 +78,7 @@ export type AppContextProviderProps = {
|
||||
}
|
||||
|
||||
export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) => {
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: userProfileResponse, mutate: mutateUserProfile, error: userProfileError } = useSWR({ url: '/account/profile', params: {} }, fetchUserProfile)
|
||||
const { data: currentWorkspaceResponse, mutate: mutateCurrentWorkspace, isLoading: isLoadingCurrentWorkspace } = useSWR({ url: '/workspaces/current', params: {} }, fetchCurrentWorkspace)
|
||||
|
||||
@ -92,10 +94,12 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
|
||||
try {
|
||||
const result = await userProfileResponse.json()
|
||||
setUserProfile(result)
|
||||
const current_version = userProfileResponse.headers.get('x-version')
|
||||
const current_env = process.env.NODE_ENV === 'development' ? 'DEVELOPMENT' : userProfileResponse.headers.get('x-env')
|
||||
const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } })
|
||||
setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
|
||||
if (!systemFeatures.branding.enabled) {
|
||||
const current_version = userProfileResponse.headers.get('x-version')
|
||||
const current_env = process.env.NODE_ENV === 'development' ? 'DEVELOPMENT' : userProfileResponse.headers.get('x-env')
|
||||
const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } })
|
||||
setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Failed to update user profile:', error)
|
||||
|
||||
@ -45,5 +45,118 @@ describe('get-icon', () => {
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toBe(`${MARKETPLACE_API_PREFIX}/plugins/${pluginId}/icon`)
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: Path traversal attempts
|
||||
* These tests document current behavior and potential security concerns
|
||||
* Note: Current implementation does not sanitize path traversal sequences
|
||||
*/
|
||||
test('handles path traversal attempts', () => {
|
||||
const pluginId = '../../../etc/passwd'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
// Current implementation includes path traversal sequences in URL
|
||||
// This is a potential security concern that should be addressed
|
||||
expect(result).toContain('../')
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('handles multiple path traversal attempts', () => {
|
||||
const pluginId = '../../../../etc/passwd'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
// Current implementation includes path traversal sequences in URL
|
||||
expect(result).toContain('../')
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('passes through URL-encoded path traversal sequences', () => {
|
||||
const pluginId = '..%2F..%2Fetc%2Fpasswd'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: Null and undefined handling
|
||||
* These tests document current behavior with invalid input types
|
||||
* Note: Current implementation converts null/undefined to strings instead of throwing
|
||||
*/
|
||||
test('handles null plugin ID', () => {
|
||||
// Current implementation converts null to string "null"
|
||||
const result = getIconFromMarketPlace(null as any)
|
||||
expect(result).toContain('null')
|
||||
// This is a potential issue - should validate input type
|
||||
})
|
||||
|
||||
test('handles undefined plugin ID', () => {
|
||||
// Current implementation converts undefined to string "undefined"
|
||||
const result = getIconFromMarketPlace(undefined as any)
|
||||
expect(result).toContain('undefined')
|
||||
// This is a potential issue - should validate input type
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: URL-sensitive characters
|
||||
* These tests verify that URL-sensitive characters are handled appropriately
|
||||
*/
|
||||
test('does not encode URL-sensitive characters', () => {
|
||||
const pluginId = 'plugin/with?special=chars#hash'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
// Note: Current implementation doesn't encode, but test documents the behavior
|
||||
expect(result).toContain(pluginId)
|
||||
expect(result).toContain('?')
|
||||
expect(result).toContain('#')
|
||||
expect(result).toContain('=')
|
||||
})
|
||||
|
||||
test('handles URL characters like & and %', () => {
|
||||
const pluginId = 'plugin&with%encoding'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
/**
|
||||
* Edge case tests: Extreme inputs
|
||||
* These tests verify behavior with unusual but valid inputs
|
||||
*/
|
||||
test('handles very long plugin ID', () => {
|
||||
const pluginId = 'a'.repeat(10000)
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
expect(result.length).toBeGreaterThan(10000)
|
||||
})
|
||||
|
||||
test('handles Unicode characters', () => {
|
||||
const pluginId = '插件-🚀-测试-日本語'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('handles control characters', () => {
|
||||
const pluginId = 'plugin\nwith\ttabs\r\nand\0null'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: XSS attempts
|
||||
* These tests verify that XSS attempts are handled appropriately
|
||||
*/
|
||||
test('handles XSS attempts with script tags', () => {
|
||||
const pluginId = '<script>alert("xss")</script>'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
// Note: Current implementation doesn't sanitize, but test documents the behavior
|
||||
})
|
||||
|
||||
test('handles XSS attempts with event handlers', () => {
|
||||
const pluginId = 'plugin"onerror="alert(1)"'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('handles XSS attempts with encoded script tags', () => {
|
||||
const pluginId = '%3Cscript%3Ealert%28%22xss%22%29%3C%2Fscript%3E'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -87,7 +87,8 @@ describe('time', () => {
|
||||
test('works with timestamps', () => {
|
||||
const date = 1705276800000 // 2024-01-15 00:00:00 UTC
|
||||
const result = formatTime({ date, dateFormat: 'YYYY-MM-DD' })
|
||||
expect(result).toContain('2024-01-1') // Account for timezone differences
|
||||
// Account for timezone differences: UTC-5 to UTC+8 can result in 2024-01-14 or 2024-01-15
|
||||
expect(result).toMatch(/^2024-01-(14|15)$/)
|
||||
})
|
||||
|
||||
test('handles ISO 8601 format', () => {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user