mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/trigger
# Conflicts:
#	api/docker/entrypoint.sh
#	api/uv.lock
#	dev/start-worker
#	docker/.env.example
#	docker/docker-compose.yaml
#	web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx
#	web/app/components/base/date-and-time-picker/date-picker/index.tsx
#	web/app/components/base/date-and-time-picker/types.ts

This commit is contained in: commit a94e650ffd
@@ -117,7 +117,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.

- **Dify for enterprise / organizations<br/>**
  We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
  We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>

> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
@@ -627,5 +627,8 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true

# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1

# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
@@ -1209,6 +1209,13 @@ class SwaggerUIConfig(BaseSettings):
    )


class TenantIsolatedTaskQueueConfig(BaseSettings):
    TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
        description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
        default=1,
    )


class FeatureConfig(
    # place the configs in alphabet order
    AppExecutionConfig,
@@ -1235,6 +1242,7 @@ class FeatureConfig(
    RagEtlConfig,
    RepositoryConfig,
    SecurityConfig,
    TenantIsolatedTaskQueueConfig,
    ToolConfig,
    UpdateConfig,
    WorkflowConfig,
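The new `TENANT_ISOLATED_TASK_CONCURRENCY` setting above is a plain pydantic-settings field, so the value from `docker/.env.example` is picked up from the environment by field name. A minimal, standalone sketch of that behaviour (the class name here is illustrative, not the one Dify uses):

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class IsolatedQueueSettings(BaseSettings):
    # Falls back to the declared default when the variable is absent.
    TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(default=1)


os.environ["TENANT_ISOLATED_TASK_CONCURRENCY"] = "4"
print(IsolatedQueueSettings().TENANT_ISOLATED_TASK_CONCURRENCY)  # -> 4
```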
@@ -40,20 +40,15 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task

logger = logging.getLogger(__name__)

@@ -249,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
            )

            if rag_pipeline_invoke_entities:
                # store the rag_pipeline_invoke_entities to object storage
                text = [item.model_dump() for item in rag_pipeline_invoke_entities]
                name = "rag_pipeline_invoke_entities.json"
                # Convert list to proper JSON string
                json_text = json.dumps(text)
                upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
                features = FeatureService.get_features(dataset.tenant_id)
                if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
                    tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
                    tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

                    if redis_client.get(tenant_pipeline_task_key):
                        # Add to waiting queue using List operations (lpush)
                        redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
                    else:
                        # Set flag and execute task
                        redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
                        rag_pipeline_run_task.delay(  # type: ignore
                            rag_pipeline_invoke_entities_file_id=upload_file.id,
                            tenant_id=dataset.tenant_id,
                        )
                else:
                    priority_rag_pipeline_run_task.delay(  # type: ignore
                        rag_pipeline_invoke_entities_file_id=upload_file.id,
                        tenant_id=dataset.tenant_id,
                    )

                RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
            # return batch, dataset, documents
            return {
                "batch": batch,
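The hunk above replaces the inline sandbox check and per-tenant Redis waiting list with a single `RagPipelineTaskProxy(...).delay()` call. A rough, hypothetical sketch of the kind of routing such a proxy could encapsulate, using only stand-in callables; the real implementation lives in `services/rag_pipeline/rag_pipeline_task_proxy.py` and may differ:

```python
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any


@dataclass
class HypotheticalRagPipelineTaskProxy:
    """Illustrative only: route sandbox tenants through a throttled per-tenant
    queue, dispatch everyone else immediately (mirrors the removed branch)."""

    tenant_id: str
    user_id: str
    entities: Sequence[Any]
    is_sandbox: Callable[[str], bool]              # stand-in for the billing check
    enqueue: Callable[[str, Sequence[Any]], None]  # stand-in for the tenant queue
    dispatch: Callable[[Sequence[Any]], None]      # stand-in for the Celery task

    def delay(self) -> None:
        if self.is_sandbox(self.tenant_id):
            self.enqueue(self.tenant_id, self.entities)
        else:
            self.dispatch(self.entities)
```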
@@ -0,0 +1,15 @@
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class DocumentTask:
    """Document task entity for document indexing operations.

    This class represents a document indexing task that can be queued
    and processed by the document indexing system.
    """

    tenant_id: str
    dataset_id: str
    document_ids: Sequence[str]
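Since `DocumentTask` is a plain dataclass, it can be constructed and serialized directly; a quick sketch, assuming the class above is importable:

```python
import json
from dataclasses import asdict

task = DocumentTask(tenant_id="t-1", dataset_id="d-1", document_ids=["doc-a", "doc-b"])
print(json.dumps(asdict(task)))
# {"tenant_id": "t-1", "dataset_id": "d-1", "document_ids": ["doc-a", "doc-b"]}
```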
@@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
            # Return composite sort key: (model_type value, model position index)
            return (model.model_type.value, position_index)

        # Deduplicate
        provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())

        # Sort using the composite sort key
        return sorted(provider_models, key=get_sort_key)
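The deduplication line relies on dict keys being unique and insertion-ordered: for each `(model, model_type, fetch_from)` key the last value seen wins, while first-seen key order is preserved. The same idea on plain tuples (illustrative data):

```python
records = [
    ("gpt-4", "llm", "predefined"),
    ("gpt-4", "llm", "customizable"),
    ("gpt-4", "llm", "predefined"),  # duplicate key, replaces the first entry's value
]
deduped = list({r: r for r in records}.values())
print(deduped)  # [('gpt-4', 'llm', 'predefined'), ('gpt-4', 'llm', 'customizable')]
```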
@@ -1,21 +1,22 @@
import hashlib
import json
import logging
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse

from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig

@@ -30,9 +31,10 @@ from core.ops.entities.trace_entity import (
    TraceTaskName,
    WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)
@@ -99,22 +101,45 @@ def datetime_to_nanos(dt: datetime | None) -> int:
    return int(dt.timestamp() * 1_000_000_000)


def string_to_trace_id128(string: str | None) -> int:
    """
    Convert any input string into a stable 128-bit integer trace ID.

    This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
    It's suitable for generating consistent, unique identifiers from strings.
    """
    if string is None:
        string = ""
    hash_object = hashlib.sha256(string.encode())

    # Take the first 16 bytes (128 bits) of the hash digest
    digest = hash_object.digest()[:16]

    # Convert to a 128-bit integer
    return int.from_bytes(digest, byteorder="big")


def error_to_string(error: Exception | str | None) -> str:
    """Convert an error to a string with traceback information."""
    error_message = "Empty Stack Trace"
    if error:
        if isinstance(error, Exception):
            string_stacktrace = "".join(traceback.format_exception(error))
            error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
        else:
            error_message = str(error)
    return error_message


def set_span_status(current_span: Span, error: Exception | str | None = None):
    """Set the status of the current span based on the presence of an error."""
    if error:
        error_string = error_to_string(error)
        current_span.set_status(Status(StatusCode.ERROR, error_string))

        if isinstance(error, Exception):
            current_span.record_exception(error)
        else:
            exception_type = error.__class__.__name__
            exception_message = str(error)
            if not exception_message:
                exception_message = repr(error)
            attributes: dict[str, AttributeValue] = {
                OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
                OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
                OTELSpanAttributes.EXCEPTION_ESCAPED: False,
                OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
            }
            current_span.add_event(name="exception", attributes=attributes)
    else:
        current_span.set_status(Status(StatusCode.OK))


def safe_json_dumps(obj: Any) -> str:
    """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
    return json.dumps(obj, default=str, ensure_ascii=False)


class ArizePhoenixDataTrace(BaseTraceInstance):
@@ -131,9 +156,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        self.tracer, self.processor = setup_tracer(arize_phoenix_config)
        self.project = arize_phoenix_config.project
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
        self.propagator = TraceContextTextMapPropagator()
        self.dify_trace_ids: set[str] = set()

    def trace(self, trace_info: BaseTraceInfo):
        logger.info("[Arize/Phoenix] Trace: %s", trace_info)
        logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
        logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
        try:
            if isinstance(trace_info, WorkflowTraceInfo):
                self.workflow_trace(trace_info)
@@ -151,7 +179,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                self.generate_name_trace(trace_info)

        except Exception as e:
            logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
            logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
            raise

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
@@ -166,15 +194,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        }
        workflow_metadata.update(trace_info.metadata)

        trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
        span_id = RandomIdGenerator().generate_span_id()
        context = SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        workflow_span = self.tracer.start_span(
            name=TraceTaskName.WORKFLOW_TRACE.value,
@@ -186,31 +208,58 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
            },
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
            context=root_span_context,
        )

        # Through workflow_run_id, get all_nodes_execution using repository
        session_factory = sessionmaker(bind=db.engine)

        # Find the app's creator account
        app_id = trace_info.metadata.get("app_id")
        if not app_id:
            raise ValueError("No app_id found in trace_info metadata")

        service_account = self.get_service_account_with_tenant(app_id)

        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=service_account,
            app_id=app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Get all executions for this workflow run
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
            workflow_run_id=trace_info.workflow_run_id
        )

        try:
            # Process workflow nodes
            for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
            for node_execution in workflow_node_executions:
                tenant_id = trace_info.tenant_id  # Use from trace_info instead
                app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
                inputs_value = node_execution.inputs or {}
                outputs_value = node_execution.outputs or {}

                created_at = node_execution.created_at or datetime.now()
                elapsed_time = node_execution.elapsed_time
                finished_at = created_at + timedelta(seconds=elapsed_time)

                process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
                process_data = node_execution.process_data or {}
                execution_metadata = node_execution.metadata or {}
                node_metadata = {str(k): v for k, v in execution_metadata.items()}

                node_metadata = {
                    "node_id": node_execution.id,
                    "node_type": node_execution.node_type,
                    "node_status": node_execution.status,
                    "tenant_id": node_execution.tenant_id,
                    "app_id": node_execution.app_id,
                    "app_name": node_execution.title,
                    "status": node_execution.status,
                    "level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
                }

                if node_execution.execution_metadata:
                    node_metadata.update(json.loads(node_execution.execution_metadata))
                node_metadata.update(
                    {
                        "node_id": node_execution.id,
                        "node_type": node_execution.node_type,
                        "node_status": node_execution.status,
                        "tenant_id": tenant_id,
                        "app_id": app_id,
                        "app_name": node_execution.title,
                        "status": node_execution.status,
                        "level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
                    }
                )

                # Determine the correct span kind based on node type
                span_kind = OpenInferenceSpanKindValues.CHAIN
@@ -223,8 +272,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                if model:
                    node_metadata["ls_model_name"] = model

                outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
                usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
                usage_data = (
                    process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
                )
                if usage_data:
                    node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
                    node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
@@ -236,17 +286,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                else:
                    span_kind = OpenInferenceSpanKindValues.CHAIN

                workflow_span_context = set_span_in_context(workflow_span)
                node_span = self.tracer.start_span(
                    name=node_execution.node_type,
                    attributes={
                        SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
                        SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
                        SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
                        SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                        SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
                        SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                        SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
                        SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
                        SpanAttributes.METADATA: safe_json_dumps(node_metadata),
                        SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
                    },
                    start_time=datetime_to_nanos(created_at),
                    context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
                    context=workflow_span_context,
                )

                try:
@@ -260,11 +313,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                        llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
                        if model:
                            llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
                        outputs = (
                            json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
                        )
                        usage_data = (
                            process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
                            process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
                        )
                        if usage_data:
                            llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
@@ -275,8 +325,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                        llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
                        node_span.set_attributes(llm_attributes)
                finally:
                    if node_execution.status == "failed":
                        set_span_status(node_span, node_execution.error)
                    else:
                        set_span_status(node_span)
                    node_span.end(end_time=datetime_to_nanos(finished_at))
        finally:
            if trace_info.error:
                set_span_status(workflow_span, trace_info.error)
            else:
                set_span_status(workflow_span)
            workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))

    def message_trace(self, trace_info: MessageTraceInfo):
@@ -322,34 +380,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
        }

        trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
        message_span_id = RandomIdGenerator().generate_span_id()
        span_context = SpanContext(
            trace_id=trace_id,
            span_id=message_span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        message_span = self.tracer.start_span(
            name=TraceTaskName.MESSAGE_TRACE.value,
            attributes=attributes,
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
            context=root_span_context,
        )

        try:
            if trace_info.error:
                message_span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.error,
                    },
                )

        # Convert outputs to string based on type
        if isinstance(trace_info.outputs, dict | list):
            outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
@@ -383,26 +425,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        if model_params := metadata_dict.get("model_parameters"):
            llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)

        message_span_context = set_span_in_context(message_span)
        llm_span = self.tracer.start_span(
            name="llm",
            attributes=llm_attributes,
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
            context=message_span_context,
        )

        try:
            if trace_info.error:
                llm_span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.error,
                    },
                )
            if trace_info.message_data.error:
                set_span_status(llm_span, trace_info.message_data.error)
            else:
                set_span_status(llm_span)
        finally:
            llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
        finally:
            if trace_info.error:
                set_span_status(message_span, trace_info.error)
            else:
                set_span_status(message_span)
            message_span.end(end_time=datetime_to_nanos(trace_info.end_time))

    def moderation_trace(self, trace_info: ModerationTraceInfo):
@@ -418,15 +460,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        }
        metadata.update(trace_info.metadata)

        trace_id = string_to_trace_id128(trace_info.message_id)
        span_id = RandomIdGenerator().generate_span_id()
        context = SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        span = self.tracer.start_span(
            name=TraceTaskName.MODERATION_TRACE.value,
@@ -445,19 +481,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
            },
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
            context=root_span_context,
        )

        try:
            if trace_info.message_data.error:
                span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.message_data.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.message_data.error,
                    },
                )
                set_span_status(span, trace_info.message_data.error)
            else:
                set_span_status(span)
        finally:
            span.end(end_time=datetime_to_nanos(trace_info.end_time))

@@ -480,15 +511,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        }
        metadata.update(trace_info.metadata)

        trace_id = string_to_trace_id128(trace_info.message_id)
        span_id = RandomIdGenerator().generate_span_id()
        context = SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        span = self.tracer.start_span(
            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
@@ -499,19 +524,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
            },
            start_time=datetime_to_nanos(start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
            context=root_span_context,
        )

        try:
            if trace_info.error:
                span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.error,
                    },
                )
                set_span_status(span, trace_info.error)
            else:
                set_span_status(span)
        finally:
            span.end(end_time=datetime_to_nanos(end_time))

@@ -533,15 +553,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        }
        metadata.update(trace_info.metadata)

        trace_id = string_to_trace_id128(trace_info.message_id)
        span_id = RandomIdGenerator().generate_span_id()
        context = SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        span = self.tracer.start_span(
            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
@@ -554,19 +568,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                "end_time": end_time.isoformat() if end_time else "",
            },
            start_time=datetime_to_nanos(start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
            context=root_span_context,
        )

        try:
            if trace_info.message_data.error:
                span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.message_data.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.message_data.error,
                    },
                )
                set_span_status(span, trace_info.message_data.error)
            else:
                set_span_status(span)
        finally:
            span.end(end_time=datetime_to_nanos(end_time))

@@ -580,20 +589,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
        }

        trace_id = string_to_trace_id128(trace_info.message_id)
        tool_span_id = RandomIdGenerator().generate_span_id()
        logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)

        # Create span context with the same trace_id as the parent
        # todo: Create with the appropriate parent span context, so that the tool span is
        # a child of the appropriate span (e.g. message span)
        span_context = SpanContext(
            trace_id=trace_id,
            span_id=tool_span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        tool_params_str = (
            json.dumps(trace_info.tool_parameters, ensure_ascii=False)
@@ -612,19 +610,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                SpanAttributes.TOOL_PARAMETERS: tool_params_str,
            },
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
            context=root_span_context,
        )

        try:
            if trace_info.error:
                span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.error,
                    },
                )
                set_span_status(span, trace_info.error)
            else:
                set_span_status(span)
        finally:
            span.end(end_time=datetime_to_nanos(trace_info.end_time))

@@ -641,15 +634,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        }
        metadata.update(trace_info.metadata)

        trace_id = string_to_trace_id128(trace_info.message_id)
        span_id = RandomIdGenerator().generate_span_id()
        context = SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)

        span = self.tracer.start_span(
            name=TraceTaskName.GENERATE_NAME_TRACE.value,
@@ -663,22 +650,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
            },
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
            context=root_span_context,
        )

        try:
            if trace_info.message_data.error:
                span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.message_data.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.message_data.error,
                    },
                )
                set_span_status(span, trace_info.message_data.error)
            else:
                set_span_status(span)
        finally:
            span.end(end_time=datetime_to_nanos(trace_info.end_time))

    def ensure_root_span(self, dify_trace_id: str | None):
        """Ensure a unique root span exists for the given Dify trace ID."""
        if str(dify_trace_id) not in self.dify_trace_ids:
            self.carrier: dict[str, str] = {}

            root_span = self.tracer.start_span(name="Dify")
            root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
            root_span.set_attribute("dify_project_name", str(self.project))
            root_span.set_attribute("dify_trace_id", str(dify_trace_id))

            with use_span(root_span, end_on_exit=False):
                self.propagator.inject(carrier=self.carrier)

            set_span_status(root_span)
            root_span.end()
            self.dify_trace_ids.add(str(dify_trace_id))

    def api_check(self):
        try:
            with self.tracer.start_span("api_check") as span:
@@ -698,26 +697,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
            raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")

    def _get_workflow_nodes(self, workflow_run_id: str):
        """Helper method to get workflow nodes"""
        workflow_nodes = db.session.scalars(
            select(
                WorkflowNodeExecutionModel.id,
                WorkflowNodeExecutionModel.tenant_id,
                WorkflowNodeExecutionModel.app_id,
                WorkflowNodeExecutionModel.title,
                WorkflowNodeExecutionModel.node_type,
                WorkflowNodeExecutionModel.status,
                WorkflowNodeExecutionModel.inputs,
                WorkflowNodeExecutionModel.outputs,
                WorkflowNodeExecutionModel.created_at,
                WorkflowNodeExecutionModel.elapsed_time,
                WorkflowNodeExecutionModel.process_data,
                WorkflowNodeExecutionModel.execution_metadata,
            ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
        ).all()
        return workflow_nodes

    def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
        """Helper method to construct LLM attributes with passed prompts."""
        attributes = {}
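Taken together, these hunks replace hand-built `SpanContext` objects with a propagator pattern: one synthetic "Dify" root span is recorded per Dify trace ID, its context is injected into a carrier dict, and every subsequent span re-extracts that carrier as its parent. A minimal, self-contained sketch of the inject/extract round trip with stock OpenTelemetry APIs (span names here are illustrative):

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("sketch")
propagator = TraceContextTextMapPropagator()

carrier: dict[str, str] = {}
root_span = tracer.start_span("Dify")
with use_span(root_span, end_on_exit=False):
    propagator.inject(carrier=carrier)  # serializes the active context as a W3C traceparent
root_span.end()

# Later spans attach to the same trace by extracting the carrier as their parent context.
parent_ctx = propagator.extract(carrier=carrier)
child = tracer.start_span("workflow", context=parent_ctx)
child.end()
```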
@@ -147,7 +147,8 @@ class ElasticSearchVector(BaseVector):

    def _get_version(self) -> str:
        info = self._client.info()
        return cast(str, info["version"]["number"])
        # remove any suffix like "-SNAPSHOT" from the version string
        return cast(str, info["version"]["number"]).split("-")[0]

    def _check_version(self):
        if parse_version(self._version) < parse_version("8.0.0"):
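The `_get_version` change only strips a build suffix before the value reaches `parse_version`; for example (illustrative version strings):

```python
print("8.11.3-SNAPSHOT".split("-")[0])  # 8.11.3
print("8.11.3".split("-")[0])           # 8.11.3 (unchanged)
```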
@@ -92,7 +92,7 @@ class WeaviateVector(BaseVector):

        # Parse gRPC configuration
        if config.grpc_endpoint:
            # Urls without scheme won't be parsed correctly in some python verions,
            # Urls without scheme won't be parsed correctly in some python versions,
            # see https://bugs.python.org/issue27657
            grpc_endpoint_with_scheme = (
                config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
@@ -0,0 +1,79 @@
import json
from collections.abc import Sequence
from typing import Any

from pydantic import BaseModel, ValidationError

from extensions.ext_redis import redis_client

_DEFAULT_TASK_TTL = 60 * 60  # 1 hour


class TaskWrapper(BaseModel):
    data: Any

    def serialize(self) -> str:
        return self.model_dump_json()

    @classmethod
    def deserialize(cls, serialized_data: str) -> "TaskWrapper":
        return cls.model_validate_json(serialized_data)


class TenantIsolatedTaskQueue:
    """
    Simple queue for tenant isolated tasks, used for rag related tenant tasks isolation.
    It uses Redis list to store tasks, and Redis key to store task waiting flag.
    Support tasks that can be serialized by json.
    """

    def __init__(self, tenant_id: str, unique_key: str):
        self._tenant_id = tenant_id
        self._unique_key = unique_key
        self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
        self._task_key = f"tenant_{unique_key}_task:{tenant_id}"

    def get_task_key(self):
        return redis_client.get(self._task_key)

    def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
        redis_client.setex(self._task_key, ttl, 1)

    def delete_task_key(self):
        redis_client.delete(self._task_key)

    def push_tasks(self, tasks: Sequence[Any]):
        serialized_tasks = []
        for task in tasks:
            # Store str list directly, maintaining full compatibility for pipeline scenarios
            if isinstance(task, str):
                serialized_tasks.append(task)
            else:
                # Use TaskWrapper to do JSON serialization for non-string tasks
                wrapper = TaskWrapper(data=task)
                serialized_data = wrapper.serialize()
                serialized_tasks.append(serialized_data)

        redis_client.lpush(self._queue, *serialized_tasks)

    def pull_tasks(self, count: int = 1) -> Sequence[Any]:
        if count <= 0:
            return []

        tasks = []
        for _ in range(count):
            serialized_task = redis_client.rpop(self._queue)
            if not serialized_task:
                break

            if isinstance(serialized_task, bytes):
                serialized_task = serialized_task.decode("utf-8")

            try:
                wrapper = TaskWrapper.deserialize(serialized_task)
                tasks.append(wrapper.data)
            except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
                # Fall back to raw string for legacy format or invalid JSON
                tasks.append(serialized_task)

        return tasks
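A usage sketch for the queue above, assuming the module is importable and Redis is reachable; the `unique_key` and payloads are illustrative:

```python
queue = TenantIsolatedTaskQueue(tenant_id="tenant-123", unique_key="pipeline")

# Producer side: mark the tenant as busy, then park follow-up work.
queue.set_task_waiting_time()  # sets the per-tenant flag with the default 1-hour TTL
queue.push_tasks(["file-id-1", {"document_id": "doc-1"}])  # strings pass through, other values are wrapped

# Consumer side: drain a bounded batch once the running task has finished.
for task in queue.pull_tasks(count=10):
    print(task)  # "file-id-1", then {"document_id": "doc-1"}
queue.delete_task_key()
```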
@@ -1,16 +1,19 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Any

from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError

logger = logging.getLogger(__name__)


class MCPTool(Tool):
    def __init__(
@@ -52,6 +55,11 @@ class MCPTool(Tool):
                yield from self._process_text_content(content)
            elif isinstance(content, ImageContent):
                yield self._process_image_content(content)
            elif isinstance(content, AudioContent):
                yield self._process_audio_content(content)
            else:
                logger.warning("Unsupported content type=%s", type(content))

        # handle MCP structured output
        if self.entity.output_schema and result.structuredContent:
            for k, v in result.structuredContent.items():
@@ -97,6 +105,10 @@ class MCPTool(Tool):
        """Process image content and return a blob message."""
        return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})

    def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
        """Process audio content and return a blob message."""
        return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})

    def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
        return MCPTool(
            entity=self.entity,
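Audio content is handled exactly like image content: the base64 payload from the MCP server is decoded into a blob and the MIME type is passed along in the message metadata. In isolation (illustrative data):

```python
import base64

encoded = base64.b64encode(b"RIFF....WAVEfmt ").decode()  # stand-in for AudioContent.data
blob = base64.b64decode(encoded)
meta = {"mime_type": "audio/wav"}                         # stand-in for AudioContent.mimeType
print(len(blob), meta)
```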
@@ -153,7 +153,11 @@ class VariablePool(BaseModel):
            return None

        node_id, name = self._selector_to_keys(selector)
        segment: Segment | None = self.variable_dictionary[node_id].get(name)
        node_map = self.variable_dictionary.get(node_id)
        if node_map is None:
            return None

        segment: Segment | None = node_map.get(name)

        if segment is None:
            return None
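The fix guards against selectors whose node has no entry at all: plain indexing would raise `KeyError` (or, if the mapping were a `defaultdict`, silently insert an empty entry), whereas `.get()` lets the method return `None`. Plain-dict illustration:

```python
variable_dictionary: dict[str, dict[str, str]] = {"node_a": {"text": "hello"}}

print(variable_dictionary.get("node_b"))             # None -> caller can return early
print(variable_dictionary["node_a"].get("missing"))  # None for a missing variable name
# variable_dictionary["node_b"] would raise KeyError instead of returning None
```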
@@ -0,0 +1,134 @@
"""
Broadcast channel for Pub/Sub messaging.
"""

import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self


class Subscription(AbstractContextManager["Subscription"], Protocol):
    """A subscription to a topic that provides an iterator over received messages.

    The subscription can be used as a context manager and will automatically
    close when exiting the context.

    Note: `Subscription` instances are not thread-safe. Each thread should create its own
    subscription.
    """

    @abstractmethod
    def __iter__(self) -> Iterator[bytes]:
        """`__iter__` returns an iterator used to consume the message from this subscription.

        If the caller did not enter the context, `__iter__` may lazily perform the setup before
        yielding messages; otherwise `__enter__` handles it.

        If the subscription is closed, then the returned iterator exits without
        raising any error.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """close closes the subscription, releases any resources associated with it."""
        ...

    def __enter__(self) -> Self:
        """`__enter__` does the setup logic of the subscription (if any), and return itself."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        self.close()
        return None

    @abstractmethod
    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """Receive the next message from the broadcast channel.

        If `timeout` is specified, this method returns `None` if no message is
        received within the given period. If `timeout` is `None`, the call blocks
        until a message is received.

        Calling receive with `timeout=None` is highly discouraged, as it is impossible to
        cancel a blocking subscription.

        :param timeout: timeout for receive message, in seconds.

        Returns:
            bytes: The received message as a byte string, or
            None: If the timeout expires before a message is received.

        Raises:
            SubscriptionClosed: If the subscription has already been closed.
        """
        ...


class Producer(Protocol):
    """Producer is an interface for message publishing. It is already bound to a specific topic.

    `Producer` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def publish(self, payload: bytes) -> None:
        """Publish a message to the bounded topic."""
        ...


class Subscriber(Protocol):
    """Subscriber is an interface for subscription creation. It is already bound to a specific topic.

    `Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def subscribe(self) -> Subscription:
        pass


class Topic(Producer, Subscriber, Protocol):
    """A named channel for publishing and subscribing to messages.

    Topics provide both read and write access. For restricted access,
    use as_producer() for write-only view or as_subscriber() for read-only view.

    `Topic` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def as_producer(self) -> Producer:
        """as_producer creates a write-only view for this topic."""
        ...

    @abstractmethod
    def as_subscriber(self) -> Subscriber:
        """as_subscriber create a read-only view for this topic."""
        ...


class BroadcastChannel(Protocol):
    """A broadcasting channel is a channel supporting broadcasting semantics.

    Each channel is identified by a topic, different topics are isolated and do not affect each other.

    There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
    a specific topic, all subscription should receive the published message.

    There are no restriction for the persistence of messages. Once a subscription is created, it
    should receive all subsequent messages published.

    `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def topic(self, topic: str) -> "Topic":
        """topic returns a `Topic` instance for the given topic name."""
        ...
@@ -0,0 +1,12 @@
class BroadcastChannelError(Exception):
    """`BroadcastChannelError` is the base class for all exceptions related
    to `BroadcastChannel`."""

    pass


class SubscriptionClosedError(BroadcastChannelError):
    """SubscriptionClosedError means that the subscription has been closed and
    methods for consuming messages should not be called."""

    pass
@@ -0,0 +1,3 @@
from .channel import BroadcastChannel

__all__ = ["BroadcastChannel"]
@@ -0,0 +1,200 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self

from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis
from redis.client import PubSub

_logger = logging.getLogger(__name__)


class BroadcastChannel:
    """
    Redis Pub/Sub based broadcast channel implementation.

    Provides "at most once" delivery semantics for messages published to channels.
    Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.

    The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
    """

    def __init__(
        self,
        redis_client: Redis,
    ):
        self._client = redis_client

    def topic(self, topic: str) -> "Topic":
        return Topic(self._client, topic)


class Topic:
    def __init__(self, redis_client: Redis, topic: str):
        self._client = redis_client
        self._topic = topic

    def as_producer(self) -> Producer:
        return self

    def publish(self, payload: bytes) -> None:
        self._client.publish(self._topic, payload)

    def as_subscriber(self) -> Subscriber:
        return self

    def subscribe(self) -> Subscription:
        return _RedisSubscription(
            pubsub=self._client.pubsub(),
            topic=self._topic,
        )


class _RedisSubscription(Subscription):
    def __init__(
        self,
        pubsub: PubSub,
        topic: str,
    ):
        # The _pubsub is None only if the subscription is closed.
        self._pubsub: PubSub | None = pubsub
        self._topic = topic
        self._closed = threading.Event()
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
        self._dropped_count = 0
        self._listener_thread: threading.Thread | None = None
        self._start_lock = threading.Lock()
        self._started = False

    def _start_if_needed(self) -> None:
        with self._start_lock:
            if self._started:
                return
            if self._closed.is_set():
                raise SubscriptionClosedError("The Redis subscription is closed")
            if self._pubsub is None:
                raise SubscriptionClosedError("The Redis subscription has been cleaned up")

            self._pubsub.subscribe(self._topic)
            _logger.debug("Subscribed to channel %s", self._topic)

            self._listener_thread = threading.Thread(
                target=self._listen,
                name=f"redis-broadcast-{self._topic}",
                daemon=True,
            )
            self._listener_thread.start()
            self._started = True

    def _listen(self) -> None:
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while starting listening."
        while not self._closed.is_set():
            raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)

            if raw_message is None:
                continue

            if raw_message.get("type") != "message":
                continue

            channel_field = raw_message.get("channel")
            if isinstance(channel_field, bytes):
                channel_name = channel_field.decode("utf-8")
            elif isinstance(channel_field, str):
                channel_name = channel_field
            else:
                channel_name = str(channel_field)

            if channel_name != self._topic:
                _logger.warning("Ignoring message from unexpected channel %s", channel_name)
                continue

            payload_bytes: bytes | None = raw_message.get("data")
            if not isinstance(payload_bytes, bytes):
                _logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
                continue

            self._enqueue_message(payload_bytes)

        _logger.debug("Listener thread stopped for channel %s", self._topic)
        pubsub.unsubscribe(self._topic)
        pubsub.close()
        _logger.debug("PubSub closed for topic %s", self._topic)
        self._pubsub = None

    def _enqueue_message(self, payload: bytes) -> None:
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                try:
                    self._queue.get_nowait()
                    self._dropped_count += 1
                    _logger.debug(
                        "Dropped message from Redis subscription, topic=%s, total_dropped=%d",
                        self._topic,
                        self._dropped_count,
                    )
                except queue.Empty:
                    continue
        return

    def _message_iterator(self) -> Generator[bytes, None, None]:
        while not self._closed.is_set():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis subscription is closed")
        self._start_if_needed()
        return iter(self._message_iterator())

    def receive(self, timeout: float | None = None) -> bytes | None:
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis subscription is closed")
        self._start_if_needed()

        try:
            item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        return item

    def __enter__(self) -> Self:
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        self.close()
        return None

    def close(self) -> None:
        if self._closed.is_set():
            return

        self._closed.set()
        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
        # method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        listener = self._listener_thread
        if listener is not None:
            listener.join(timeout=1.0)
            self._listener_thread = None
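Assuming the module above is importable, a usage sketch: the subscription is entered before publishing (delivery is at most once, so late subscribers miss earlier messages), and `receive` returns `None` on timeout:

```python
from redis import Redis

channel = BroadcastChannel(Redis())  # decode_responses must stay False (the default)
topic = channel.topic("workflow-events")

with topic.subscribe() as subscription:
    topic.as_producer().publish(b'{"event": "node_finished"}')
    message = subscription.receive(timeout=1.0)
    print(message)  # b'{"event": "node_finished"}' or None if nothing arrived in time
```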
@@ -110,7 +110,7 @@ class Account(UserMixin, TypeBase):
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
    )

    role: TenantAccountRole | None = field(default=None, init=False)
@@ -250,7 +250,9 @@ class Tenant(TypeBase):
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
    )
    updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
    )

    def get_accounts(self) -> list[Account]:
        return list(
@@ -289,7 +291,7 @@ class TenantAccountJoin(TypeBase):
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
    )


@@ -310,7 +312,7 @@ class AccountIntegrate(TypeBase):
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False
        DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
    )


@@ -396,5 +398,5 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
    )
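All of the model changes in this commit follow the same pattern: adding `onupdate=func.current_timestamp()` makes SQLAlchemy refresh the column on every UPDATE instead of leaving the insert-time server default in place. A minimal standalone model showing the idiom (table and class names are illustrative):

```python
from datetime import datetime

from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Example(Base):
    __tablename__ = "example"

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    # Set on INSERT by the server default, refreshed on every ORM/Core UPDATE by onupdate.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )
```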
@@ -61,18 +61,20 @@ class Dataset(Base):
    created_by = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = mapped_column(StringUUID, nullable=True)
    updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    embedding_model = mapped_column(db.String(255), nullable=True)
    embedding_model_provider = mapped_column(db.String(255), nullable=True)
    keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
    updated_at = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    embedding_model = mapped_column(sa.String(255), nullable=True)
    embedding_model_provider = mapped_column(sa.String(255), nullable=True)
    keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
    collection_binding_id = mapped_column(StringUUID, nullable=True)
    retrieval_model = mapped_column(JSONB, nullable=True)
    built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
    built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    icon_info = mapped_column(JSONB, nullable=True)
    runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
    runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'::character varying"))
    pipeline_id = mapped_column(StringUUID, nullable=True)
    chunk_structure = mapped_column(db.String(255), nullable=True)
    enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
    chunk_structure = mapped_column(sa.String(255), nullable=True)
    enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))

    @property
    def total_documents(self):
@@ -399,7 +401,9 @@ class Document(Base):
    archived_reason = mapped_column(String(255), nullable=True)
    archived_by = mapped_column(StringUUID, nullable=True)
    archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    doc_type = mapped_column(String(40), nullable=True)
    doc_metadata = mapped_column(JSONB, nullable=True)
    doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
@@ -716,7 +720,9 @@ class DocumentSegment(Base):
    created_by = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    error = mapped_column(sa.Text, nullable=True)
@@ -881,7 +887,7 @@ class ChildChunk(Base):
    )
    updated_by = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
        DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
    )
    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@@ -1036,8 +1042,8 @@ class TidbAuthBinding(Base):
    tenant_id = mapped_column(StringUUID, nullable=True)
    cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
    cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
    active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
    status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    status = mapped_column(String(255), nullable=False, server_default=sa.text("'CREATING'::character varying"))
    account: Mapped[str] = mapped_column(String(255), nullable=False)
    password: Mapped[str] = mapped_column(String(255), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -1088,7 +1094,9 @@ class ExternalKnowledgeApis(Base):
    created_by = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    def to_dict(self) -> dict[str, Any]:
        return {
@@ -1141,7 +1149,9 @@ class ExternalKnowledgeBindings(Base):
    created_by = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )


class DatasetAutoDisableLog(Base):
@@ -1197,7 +1207,7 @@ class DatasetMetadata(Base):
        DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
        DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1224,44 +1234,48 @@ class DatasetMetadataBinding(Base):
|
|||
|
||||
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_built_in_templates"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
chunk_structure = mapped_column(db.String(255), nullable=False)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=False)
|
||||
icon = mapped_column(sa.JSON, nullable=False)
|
||||
yaml_content = mapped_column(sa.Text, nullable=False)
|
||||
copyright = mapped_column(db.String(255), nullable=False)
|
||||
privacy_policy = mapped_column(db.String(255), nullable=False)
|
||||
copyright = mapped_column(sa.String(255), nullable=False)
|
||||
privacy_policy = mapped_column(sa.String(255), nullable=False)
|
||||
position = mapped_column(sa.Integer, nullable=False)
|
||||
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(db.String(255), nullable=False)
|
||||
language = mapped_column(sa.String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_customized_templates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
|
||||
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
|
||||
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
chunk_structure = mapped_column(db.String(255), nullable=False)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=False)
|
||||
icon = mapped_column(sa.JSON, nullable=False)
|
||||
position = mapped_column(sa.Integer, nullable=False)
|
||||
yaml_content = mapped_column(sa.Text, nullable=False)
|
||||
install_count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(db.String(255), nullable=False)
|
||||
language = mapped_column(sa.String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
|
|
@ -1273,19 +1287,21 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
|||
|
||||
class Pipeline(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipelines"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False, server_default=sa.text("''::character varying"))
|
||||
workflow_id = mapped_column(StringUUID, nullable=True)
|
||||
is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
|
|
@ -1294,16 +1310,16 @@ class Pipeline(Base): # type: ignore[name-defined]
|
|||
class DocumentPipelineExecutionLog(Base):
|
||||
__tablename__ = "document_pipeline_execution_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
|
||||
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
|
||||
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
pipeline_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
datasource_type = mapped_column(db.String(255), nullable=False)
|
||||
datasource_type = mapped_column(sa.String(255), nullable=False)
|
||||
datasource_info = mapped_column(sa.Text, nullable=False)
|
||||
datasource_node_id = mapped_column(db.String(255), nullable=False)
|
||||
datasource_node_id = mapped_column(sa.String(255), nullable=False)
|
||||
input_data = mapped_column(sa.JSON, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -1311,12 +1327,14 @@ class DocumentPipelineExecutionLog(Base):
|
|||
|
||||
class PipelineRecommendedPlugin(Base):
|
||||
__tablename__ = "pipeline_recommended_plugins"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
plugin_id = mapped_column(sa.Text, nullable=False)
|
||||
provider_name = mapped_column(sa.Text, nullable=False)
|
||||
position = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
active = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ class App(Base):
|
|||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
|
|
@ -314,7 +316,9 @@ class AppModelConfig(Base):
|
|||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
opening_statement = mapped_column(sa.Text)
|
||||
suggested_questions = mapped_column(sa.Text)
|
||||
suggested_questions_after_answer = mapped_column(sa.Text)
|
||||
|
|
@ -545,7 +549,9 @@ class RecommendedApp(Base):
|
|||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
|
|
@ -644,7 +650,9 @@ class Conversation(Base):
|
|||
read_account_id = mapped_column(StringUUID)
|
||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||
message_annotations = db.relationship(
|
||||
|
|
@ -948,7 +956,9 @@ class Message(Base):
|
|||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
|
@ -1296,7 +1306,9 @@ class MessageFeedback(Base):
|
|||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
|
|
@ -1378,7 +1390,9 @@ class MessageAnnotation(Base):
|
|||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
|
|
@ -1443,7 +1457,9 @@ class AppAnnotationSetting(Base):
|
|||
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def collection_binding_detail(self):
|
||||
|
|
@ -1471,7 +1487,9 @@ class OperationLog(Base):
|
|||
content = mapped_column(sa.JSON)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class DefaultEndUserSessionID(StrEnum):
|
||||
|
|
@ -1510,7 +1528,9 @@ class EndUser(Base, UserMixin):
|
|||
|
||||
session_id: Mapped[str] = mapped_column()
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class AppMCPServer(Base):
|
||||
|
|
@ -1530,7 +1550,9 @@ class AppMCPServer(Base):
|
|||
parameters = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_server_code(n: int) -> str:
|
||||
|
|
@ -1576,7 +1598,9 @@ class Site(Base):
|
|||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
code = mapped_column(String(255))
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1,62 +1,66 @@
|
|||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "datasource_oauth_params"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
||||
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
|
||||
|
||||
class DatasourceProvider(Base):
|
||||
__tablename__ = "datasource_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
||||
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
|
||||
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default")
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class DatasourceOauthTenantParamConfig(Base):
|
||||
__tablename__ = "datasource_oauth_tenant_params"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={})
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -72,7 +72,9 @@ class Provider(Base):
|
|||
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -135,7 +137,9 @@ class ProviderModel(Base):
|
|||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def credential(self):
|
||||
|
|
@ -170,7 +174,9 @@ class TenantDefaultModel(Base):
|
|||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(Base):
|
||||
|
|
@ -185,7 +191,9 @@ class TenantPreferredModelProvider(Base):
|
|||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderOrder(Base):
|
||||
|
|
@ -212,7 +220,9 @@ class ProviderOrder(Base):
|
|||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderModelSetting(Base):
|
||||
|
|
@ -234,7 +244,9 @@ class ProviderModelSetting(Base):
|
|||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(Base):
|
||||
|
|
@ -259,7 +271,9 @@ class LoadBalancingModelConfig(Base):
|
|||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderCredential(Base):
|
||||
|
|
@ -279,7 +293,9 @@ class ProviderCredential(Base):
|
|||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ProviderModelCredential(Base):
|
||||
|
|
@ -307,4 +323,6 @@ class ProviderModelCredential(Base):
|
|||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -140,8 +140,9 @@ class Workflow(Base):
|
|||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=naive_utc_now(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
_environment_variables: Mapped[str] = mapped_column(
|
||||
"environment_variables", sa.Text, nullable=False, server_default="{}"
|
||||
|
|
@ -150,7 +151,7 @@ class Workflow(Base):
|
|||
"conversation_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_rag_pipeline_variables: Mapped[str] = mapped_column(
|
||||
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
|
||||
"rag_pipeline_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
VERSION_DRAFT = "draft"
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ from models.model import UploadFile
|
|||
from models.provider_ids import ModelProviderID
|
||||
from models.source import DataSourceOauthBinding
|
||||
from models.workflow import Workflow
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
ChildChunkUpdateArgs,
|
||||
KnowledgeConfig,
|
||||
|
|
@ -79,7 +80,6 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
|||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
||||
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
|
||||
|
|
@ -1694,7 +1694,7 @@ class DocumentService:
|
|||
|
||||
# trigger async task
|
||||
if document_ids:
|
||||
document_indexing_task.delay(dataset.id, document_ids)
|
||||
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
|
||||
if duplicate_document_ids:
|
||||
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,83 @@
import logging
from collections.abc import Callable, Sequence
from dataclasses import asdict
from functools import cached_property

from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task

logger = logging.getLogger(__name__)


class DocumentIndexingTaskProxy:
    def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
        self._tenant_id = tenant_id
        self._dataset_id = dataset_id
        self._document_ids = document_ids
        self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")

    @cached_property
    def features(self):
        return FeatureService.get_features(self._tenant_id)

    def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
        logger.info("send dataset %s to direct queue", self._dataset_id)
        task_func.delay(  # type: ignore
            tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
        )

    def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
        logger.info("send dataset %s to tenant queue", self._dataset_id)
        if self._tenant_isolated_task_queue.get_task_key():
            # A task is already running for this tenant: add to the waiting queue (lpush)
            self._tenant_isolated_task_queue.push_tasks(
                [
                    asdict(
                        DocumentTask(
                            tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
                        )
                    )
                ]
            )
            logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
        else:
            # Set the running flag and execute the task immediately
            self._tenant_isolated_task_queue.set_task_waiting_time()
            task_func.delay(  # type: ignore
                tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
            )
            logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)

    def _send_to_default_tenant_queue(self):
        self._send_to_tenant_queue(normal_document_indexing_task)

    def _send_to_priority_tenant_queue(self):
        self._send_to_tenant_queue(priority_document_indexing_task)

    def _send_to_priority_direct_queue(self):
        self._send_to_direct_queue(priority_document_indexing_task)

    def _dispatch(self):
        logger.info(
            "dispatch args: %s - %s - %s",
            self._tenant_id,
            self.features.billing.enabled,
            self.features.billing.subscription.plan,
        )
        # dispatch to a tenant-isolated indexing queue when billing is enabled
        if self.features.billing.enabled:
            if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
                # sandbox plan: normal indexing queue with the tenant's own sub-queue
                self._send_to_default_tenant_queue()
            else:
                # other plans: priority indexing queue with the tenant's own sub-queue
                self._send_to_priority_tenant_queue()
        else:
            # no billing (e.g. self-hosted or enterprise): priority queue without tenant isolation
            self._send_to_priority_direct_queue()

    def delay(self):
        self._dispatch()
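For context, the DocumentService hunk later in this diff swaps `document_indexing_task.delay(dataset.id, document_ids)` for this proxy; the caller only sees the constructor and `delay()`. A minimal usage sketch with placeholder IDs (the IDs below are illustrative, not taken from the diff):

```python
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy

# Placeholder identifiers; in DocumentService they come from the Dataset row
# and the newly created Document records.
tenant_id = "tenant-uuid"
dataset_id = "dataset-uuid"
document_ids = ["document-uuid-1", "document-uuid-2"]

# Billing-plan lookup, queue selection, and the Redis "task running" flag
# handshake are all hidden behind delay().
DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids).delay()
```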
|
@ -0,0 +1,106 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import cached_property
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPipelineTaskProxy:
|
||||
# Default uploaded file name for rag pipeline invoke entities
|
||||
_RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"
|
||||
|
||||
def __init__(
|
||||
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
|
||||
):
|
||||
self._dataset_tenant_id = dataset_tenant_id
|
||||
self._user_id = user_id
|
||||
self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
|
||||
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")
|
||||
|
||||
@cached_property
|
||||
def features(self):
|
||||
return FeatureService.get_features(self._dataset_tenant_id)
|
||||
|
||||
def _upload_invoke_entities(self) -> str:
|
||||
text = [item.model_dump() for item in self._rag_pipeline_invoke_entities]
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
|
||||
)
|
||||
return upload_file.id
|
||||
|
||||
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||
logger.info("send file %s to direct queue", upload_file_id)
|
||||
task_func.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||
tenant_id=self._dataset_tenant_id,
|
||||
)
|
||||
|
||||
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||
logger.info("send file %s to tenant queue", upload_file_id)
|
||||
if self._tenant_isolated_task_queue.get_task_key():
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
|
||||
logger.info("push tasks: %s", upload_file_id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
self._tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||
tenant_id=self._dataset_tenant_id,
|
||||
)
|
||||
logger.info("init tasks: %s", upload_file_id)
|
||||
|
||||
def _send_to_default_tenant_queue(self, upload_file_id: str):
|
||||
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
|
||||
|
||||
def _send_to_priority_tenant_queue(self, upload_file_id: str):
|
||||
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
|
||||
|
||||
def _send_to_priority_direct_queue(self, upload_file_id: str):
|
||||
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
|
||||
|
||||
def _dispatch(self):
|
||||
upload_file_id = self._upload_invoke_entities()
|
||||
if not upload_file_id:
|
||||
raise ValueError("upload_file_id is empty")
|
||||
|
||||
logger.info(
|
||||
"dispatch args: %s - %s - %s",
|
||||
self._dataset_tenant_id,
|
||||
self.features.billing.enabled,
|
||||
self.features.billing.subscription.plan,
|
||||
)
|
||||
|
||||
# dispatch to different pipeline queue with tenant isolation when billing enabled
|
||||
if self.features.billing.enabled:
|
||||
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
|
||||
self._send_to_default_tenant_queue(upload_file_id)
|
||||
else:
|
||||
# dispatch to priority pipeline queue with tenant isolation for other plans
|
||||
self._send_to_priority_tenant_queue(upload_file_id)
|
||||
else:
|
||||
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
|
||||
self._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
def delay(self):
|
||||
if not self._rag_pipeline_invoke_entities:
|
||||
logger.warning(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
|
||||
self._dataset_tenant_id,
|
||||
self._user_id,
|
||||
)
|
||||
return
|
||||
self._dispatch()
|
||||
|
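Unlike the document-indexing proxy, `RagPipelineTaskProxy` does not push the payload itself through the queue: it first uploads the serialized invoke entities as a JSON file and only the upload file ID travels through Celery and the tenant queue. A usage sketch under stated assumptions (the module path `services.rag_pipeline_task_proxy` and the way the entities are built are assumptions, since neither is shown in this diff):

```python
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from services.rag_pipeline_task_proxy import RagPipelineTaskProxy  # assumed module path

# The entities are normally produced by the pipeline generator; their fields are
# not shown in this diff, so an empty list stands in here as a placeholder.
entities: list[RagPipelineInvokeEntity] = []

proxy = RagPipelineTaskProxy(
    dataset_tenant_id="tenant-uuid",  # placeholder
    user_id="account-uuid",           # placeholder
    rag_pipeline_invoke_entities=entities,
)
# With an empty entity list, delay() logs a warning and returns without dispatching;
# with real entities it uploads them as JSON and routes by billing plan.
proxy.delay()
```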
|
@ -126,7 +126,7 @@ workflow:
|
|||
type: mixed
|
||||
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
|
||||
plugin_id: langgenius/jina_datasource
|
||||
provider_name: jina
|
||||
provider_name: jinareader
|
||||
provider_type: website_crawl
|
||||
selected: false
|
||||
title: Jina Reader
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ workflow:
|
|||
type: mixed
|
||||
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
|
||||
plugin_id: langgenius/jina_datasource
|
||||
provider_name: jina
|
||||
provider_name: jinareader
|
||||
provider_type: website_crawl
|
||||
selected: false
|
||||
title: Jina Reader
|
||||
|
|
|
|||
|
|
@ -419,7 +419,7 @@ workflow:
|
|||
type: mixed
|
||||
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
|
||||
plugin_id: langgenius/jina_datasource
|
||||
provider_name: jina
|
||||
provider_name: jinareader
|
||||
provider_type: website_crawl
|
||||
selected: false
|
||||
title: Jina Reader
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -22,8 +25,24 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
|||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
.. warning:: TO BE DEPRECATED
|
||||
This function will be deprecated and removed in a future version.
|
||||
Use normal_document_indexing_task or priority_document_indexing_task instead.
|
||||
|
||||
Usage: document_indexing_task.delay(dataset_id, document_ids)
|
||||
"""
|
||||
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
|
||||
|
||||
def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Process document for tasks
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: _document_indexing(dataset_id, document_ids)
|
||||
"""
|
||||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
|
|
@ -87,3 +106,63 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
|||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
|
||||
finally:
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
|
||||
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Async process document
|
||||
:param tenant_id:
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||
"""
|
||||
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
|
||||
|
||||
|
||||
@shared_task(queue="priority_dataset")
|
||||
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Priority async process document
|
||||
:param tenant_id:
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||
"""
|
||||
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
|
||||
|
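The drain loop above round-trips each queued payload through `DocumentTask`: the proxy pushes `asdict(DocumentTask(...))` and the worker rebuilds it with `DocumentTask(**next_task)`. The entity itself is defined in `core/entities/document_task.py`, which is not part of this excerpt; based on how it is constructed here, it is presumably a small dataclass along these lines (a sketch inferred from the call sites, not the actual definition):

```python
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class DocumentTask:
    """Payload stored in the tenant-isolated document indexing queue.

    Field names are inferred from the keyword arguments used in this diff.
    """

    tenant_id: str
    dataset_id: str
    document_ids: Sequence[str]
```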
|
|
|||
|
|
@ -12,8 +12,10 @@ from celery import shared_task # type: ignore
|
|||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant
|
||||
|
|
@ -22,6 +24,8 @@ from models.enums import WorkflowRunTriggeredFrom
|
|||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.file_service import FileService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="priority_pipeline")
|
||||
def priority_rag_pipeline_run_task(
|
||||
|
|
@ -69,6 +73,27 @@ def priority_rag_pipeline_run_task(
|
|||
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
|
||||
raise
|
||||
finally:
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
|
||||
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
priority_rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
file_service = FileService(db.engine)
|
||||
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
|
||||
db.session.close()
|
||||
|
|
|
|||
|
|
@ -12,17 +12,20 @@ from celery import shared_task # type: ignore
|
|||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.file_service import FileService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
|
|
@ -70,26 +73,27 @@ def rag_pipeline_run_task(
|
|||
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
|
||||
raise
|
||||
finally:
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
|
||||
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
|
||||
|
||||
if next_file_id:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
redis_client.delete(tenant_pipeline_task_key)
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
file_service = FileService(db.engine)
|
||||
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
|
||||
db.session.close()
|
||||
|
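The `TenantIsolatedTaskQueue` class that replaces the inline `redis_client` calls lives in `core/rag/pipeline/queue.py`, which is not part of this excerpt. The following is only a rough sketch of what its interface implies, inferred from the call sites above and the integration tests later in this excerpt: the key names match the test assertions, the one-hour lease TTL mirrors the `setex(..., 60 * 60, 1)` call in the pre-refactor code, and serialization details such as `TaskWrapper` are simplified away.

```python
import json
from collections.abc import Sequence
from typing import Any

from extensions.ext_redis import redis_client

# One hour, matching the setex TTL used by the pre-refactor inline code.
_DEFAULT_TASK_TTL_SECONDS = 60 * 60


class TenantIsolatedTaskQueue:
    """Sketch of the Redis-backed per-tenant task queue used in this diff (not the real implementation)."""

    def __init__(self, tenant_id: str, unique_key: str):
        self._tenant_id = tenant_id
        self._unique_key = unique_key
        # Key names follow the assertions in the integration tests below.
        self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
        self._task_key = f"tenant_{unique_key}_task:{tenant_id}"

    def get_task_key(self):
        # Non-None means a worker is currently processing tasks for this tenant.
        return redis_client.get(self._task_key)

    def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL_SECONDS) -> None:
        # The flag acts as a lease; it expires if the worker dies without cleaning up.
        redis_client.setex(self._task_key, ttl, 1)

    def delete_task_key(self) -> None:
        redis_client.delete(self._task_key)

    def push_tasks(self, tasks: Sequence[Any]) -> None:
        # lpush on push plus rpop on pull gives FIFO ordering.
        serialized = [t if isinstance(t, str) else json.dumps(t) for t in tasks]
        redis_client.lpush(self._queue, *serialized)

    def pull_tasks(self, count: int = 1) -> list[Any]:
        tasks: list[Any] = []
        for _ in range(count):
            raw = redis_client.rpop(self._queue)
            if raw is None:
                break
            value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            try:
                # Dict payloads (e.g. DocumentTask) come back as dicts...
                tasks.append(json.loads(value))
            except (TypeError, ValueError):
                # ...plain string payloads (e.g. upload file IDs) come back unchanged.
                tasks.append(value)
        return tasks
```

The number of tasks pulled per drain is bounded by `TENANT_ISOLATED_TASK_CONCURRENCY`, so a single tenant with a deep backlog cannot monopolize the shared Celery queue.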
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,595 @@
|
|||
"""
|
||||
Integration tests for TenantIsolatedTaskQueue using testcontainers.
|
||||
|
||||
These tests verify the Redis-based task queue functionality with real Redis instances,
|
||||
testing tenant isolation, task serialization, and queue operations in a realistic environment.
|
||||
Includes compatibility tests for migrating from legacy string-only queues.
|
||||
|
||||
All tests use generic naming to avoid coupling to specific business implementations.
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestTask:
|
||||
"""Test task data structure for testing complex object serialization."""
|
||||
|
||||
task_id: str
|
||||
tenant_id: str
|
||||
data: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueueIntegration:
|
||||
"""Integration tests for TenantIsolatedTaskQueue using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def fake(self):
|
||||
"""Faker instance for generating test data."""
|
||||
return Faker()
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_and_account(self, db_session_with_containers, fake):
|
||||
"""Create test tenant and account for testing."""
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tenant, account
|
||||
|
||||
@pytest.fixture
|
||||
def test_queue(self, test_tenant_and_account):
|
||||
"""Create a generic test queue for testing."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
|
||||
|
||||
@pytest.fixture
|
||||
def secondary_queue(self, test_tenant_and_account):
|
||||
"""Create a secondary test queue for testing isolation."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
|
||||
|
||||
def test_queue_initialization(self, test_tenant_and_account):
|
||||
"""Test queue initialization with correct key generation."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
|
||||
|
||||
assert queue._tenant_id == tenant.id
|
||||
assert queue._unique_key == "test-key"
|
||||
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
|
||||
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
|
||||
|
||||
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
|
||||
"""Test that different tenants have isolated queues."""
|
||||
tenant1, _ = test_tenant_and_account
|
||||
|
||||
# Create second tenant
|
||||
tenant2 = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant2)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1.id, "same-key")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2.id, "same-key")
|
||||
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
|
||||
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
|
||||
|
||||
def test_key_isolation(self, test_tenant_and_account):
|
||||
"""Test that different keys have isolated queues."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant.id, "key2")
|
||||
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
assert queue1._queue == f"tenant_self_key1_task_queue:{tenant.id}"
|
||||
assert queue2._queue == f"tenant_self_key2_task_queue:{tenant.id}"
|
||||
|
||||
def test_task_key_operations(self, test_queue):
|
||||
"""Test task key operations (get, set, delete)."""
|
||||
# Initially no task key should exist
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
# Set task waiting time with default TTL
|
||||
test_queue.set_task_waiting_time()
|
||||
task_key = test_queue.get_task_key()
|
||||
# Redis returns bytes, convert to string for comparison
|
||||
assert task_key in (b"1", "1")
|
||||
|
||||
# Set task waiting time with custom TTL
|
||||
custom_ttl = 30
|
||||
test_queue.set_task_waiting_time(custom_ttl)
|
||||
task_key = test_queue.get_task_key()
|
||||
assert task_key in (b"1", "1")
|
||||
|
||||
# Delete task key
|
||||
test_queue.delete_task_key()
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
def test_push_and_pull_string_tasks(self, test_queue):
|
||||
"""Test pushing and pulling string tasks."""
|
||||
tasks = ["task1", "task2", "task3"]
|
||||
|
||||
# Push tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull tasks (FIFO order)
|
||||
pulled_tasks = test_queue.pull_tasks(3)
|
||||
|
||||
# Should get tasks in FIFO order (lpush + rpop = FIFO)
|
||||
assert pulled_tasks == ["task1", "task2", "task3"]
|
||||
|
||||
def test_push_and_pull_multiple_tasks(self, test_queue):
|
||||
"""Test pushing and pulling multiple tasks at once."""
|
||||
tasks = ["task1", "task2", "task3", "task4", "task5"]
|
||||
|
||||
# Push tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull multiple tasks
|
||||
pulled_tasks = test_queue.pull_tasks(3)
|
||||
assert len(pulled_tasks) == 3
|
||||
assert pulled_tasks == ["task1", "task2", "task3"]
|
||||
|
||||
# Pull remaining tasks
|
||||
remaining_tasks = test_queue.pull_tasks(5)
|
||||
assert len(remaining_tasks) == 2
|
||||
assert remaining_tasks == ["task4", "task5"]
|
||||
|
||||
def test_push_and_pull_complex_objects(self, test_queue, fake):
|
||||
"""Test pushing and pulling complex object tasks."""
|
||||
# Create complex task objects as dictionaries (not dataclass instances)
|
||||
tasks = [
|
||||
{
|
||||
"task_id": str(uuid4()),
|
||||
"tenant_id": test_queue._tenant_id,
|
||||
"data": {
|
||||
"file_id": str(uuid4()),
|
||||
"content": fake.text(),
|
||||
"metadata": {"size": fake.random_int(1000, 10000)},
|
||||
},
|
||||
"metadata": {"created_at": fake.iso8601(), "tags": fake.words(3)},
|
||||
},
|
||||
{
|
||||
"task_id": str(uuid4()),
|
||||
"tenant_id": test_queue._tenant_id,
|
||||
"data": {
|
||||
"file_id": str(uuid4()),
|
||||
"content": "测试中文内容",
|
||||
"metadata": {"size": fake.random_int(1000, 10000)},
|
||||
},
|
||||
"metadata": {"created_at": fake.iso8601(), "tags": ["中文", "测试", "emoji🚀"]},
|
||||
},
|
||||
]
|
||||
|
||||
# Push complex tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull tasks
|
||||
pulled_tasks = test_queue.pull_tasks(2)
|
||||
assert len(pulled_tasks) == 2
|
||||
|
||||
# Verify deserialized tasks match original (FIFO order)
|
||||
for i, pulled_task in enumerate(pulled_tasks):
|
||||
original_task = tasks[i] # FIFO order
|
||||
assert isinstance(pulled_task, dict)
|
||||
assert pulled_task["task_id"] == original_task["task_id"]
|
||||
assert pulled_task["tenant_id"] == original_task["tenant_id"]
|
||||
assert pulled_task["data"] == original_task["data"]
|
||||
assert pulled_task["metadata"] == original_task["metadata"]
|
||||
|
||||
def test_mixed_task_types(self, test_queue, fake):
|
||||
"""Test pushing and pulling mixed string and object tasks."""
|
||||
string_task = "simple_string_task"
|
||||
object_task = {
|
||||
"task_id": str(uuid4()),
|
||||
"dataset_id": str(uuid4()),
|
||||
"document_ids": [str(uuid4()) for _ in range(3)],
|
||||
}
|
||||
|
||||
tasks = [string_task, object_task, "another_string"]
|
||||
|
||||
# Push mixed tasks
|
||||
test_queue.push_tasks(tasks)
|
||||
|
||||
# Pull all tasks
|
||||
pulled_tasks = test_queue.pull_tasks(3)
|
||||
assert len(pulled_tasks) == 3
|
||||
|
||||
# Verify types and content
|
||||
assert pulled_tasks[0] == string_task
|
||||
assert isinstance(pulled_tasks[1], dict)
|
||||
assert pulled_tasks[1] == object_task
|
||||
assert pulled_tasks[2] == "another_string"
|
||||
|
||||
def test_empty_queue_operations(self, test_queue):
|
||||
"""Test operations on empty queue."""
|
||||
# Pull from empty queue
|
||||
tasks = test_queue.pull_tasks(5)
|
||||
assert tasks == []
|
||||
|
||||
# Pull zero or negative count
|
||||
assert test_queue.pull_tasks(0) == []
|
||||
assert test_queue.pull_tasks(-1) == []
|
||||
|
||||
def test_task_ttl_expiration(self, test_queue):
|
||||
"""Test task key TTL expiration."""
|
||||
# Set task with short TTL
|
||||
short_ttl = 2
|
||||
test_queue.set_task_waiting_time(short_ttl)
|
||||
|
||||
# Verify task key exists
|
||||
assert test_queue.get_task_key() in (b"1", "1")
|
||||
|
||||
# Wait for TTL to expire
|
||||
time.sleep(short_ttl + 1)
|
||||
|
||||
# Verify task key has expired
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
def test_large_task_batch(self, test_queue, fake):
|
||||
"""Test handling large batches of tasks."""
|
||||
# Create large batch of tasks
|
||||
large_batch = []
|
||||
for i in range(100):
|
||||
task = {
|
||||
"task_id": str(uuid4()),
|
||||
"index": i,
|
||||
"data": fake.text(max_nb_chars=100),
|
||||
"metadata": {"batch_id": str(uuid4())},
|
||||
}
|
||||
large_batch.append(task)
|
||||
|
||||
# Push large batch
|
||||
test_queue.push_tasks(large_batch)
|
||||
|
||||
# Pull all tasks
|
||||
pulled_tasks = test_queue.pull_tasks(100)
|
||||
assert len(pulled_tasks) == 100
|
||||
|
||||
# Verify all tasks were retrieved correctly (FIFO order)
|
||||
for i, task in enumerate(pulled_tasks):
|
||||
assert isinstance(task, dict)
|
||||
assert task["index"] == i # FIFO order
|
||||
|
||||
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
|
||||
"""Test concurrent operations on different queues."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
|
||||
# Create multiple queues for the same tenant
|
||||
queue1 = TenantIsolatedTaskQueue(tenant.id, "queue1")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant.id, "queue2")
|
||||
|
||||
# Push tasks to different queues
|
||||
queue1.push_tasks(["task1_queue1", "task2_queue1"])
|
||||
queue2.push_tasks(["task1_queue2", "task2_queue2"])
|
||||
|
||||
# Verify queues are isolated
|
||||
tasks1 = queue1.pull_tasks(2)
|
||||
tasks2 = queue2.pull_tasks(2)
|
||||
|
||||
assert tasks1 == ["task1_queue1", "task2_queue1"]
|
||||
assert tasks2 == ["task1_queue2", "task2_queue2"]
|
||||
assert tasks1 != tasks2
|
||||
|
||||
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
|
||||
"""Test TaskWrapper serialization and deserialization roundtrip."""
|
||||
# Create complex nested data
|
||||
complex_data = {
|
||||
"id": str(uuid4()),
|
||||
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5], "unicode": "测试中文", "emoji": "🚀"}},
|
||||
"metadata": {"created_at": fake.iso8601(), "tags": ["tag1", "tag2", "tag3"]},
|
||||
}
|
||||
|
||||
# Create wrapper and serialize
|
||||
wrapper = TaskWrapper(data=complex_data)
|
||||
serialized = wrapper.serialize()
|
||||
|
||||
# Verify serialization
|
||||
assert isinstance(serialized, str)
|
||||
assert "测试中文" in serialized
|
||||
assert "🚀" in serialized
|
||||
|
||||
# Deserialize and verify
|
||||
deserialized_wrapper = TaskWrapper.deserialize(serialized)
|
||||
assert deserialized_wrapper.data == complex_data
|
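# TaskWrapper, as exercised above, is assumed to be a thin envelope that serializes any
# JSON-compatible payload to a string for storage in Redis and restores it unchanged,
# including non-ASCII text and emoji; nothing beyond serialize()/deserialize() is assumed here.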
||||
|
||||
def test_error_handling_invalid_json(self, test_queue):
|
||||
"""Test error handling for invalid JSON in wrapped tasks."""
|
||||
# Manually create invalid JSON task (not a valid TaskWrapper JSON)
|
||||
invalid_json_task = "invalid json data"
|
||||
|
||||
# Push invalid task directly to Redis
|
||||
redis_client.lpush(test_queue._queue, invalid_json_task)
|
||||
|
||||
# Pull task - should fall back to string since it's not valid JSON
|
||||
task = test_queue.pull_tasks(1)
|
||||
assert task[0] == invalid_json_task
|
||||
|
||||
def test_real_world_batch_processing_scenario(self, test_queue, fake):
|
||||
"""Test realistic batch processing scenario."""
|
||||
# Simulate batch processing tasks
|
||||
batch_tasks = []
|
||||
for i in range(3):
|
||||
task = {
|
||||
"file_id": str(uuid4()),
|
||||
"tenant_id": test_queue._tenant_id,
|
||||
"user_id": str(uuid4()),
|
||||
"processing_config": {
|
||||
"model": fake.random_element(["model_a", "model_b", "model_c"]),
|
||||
"temperature": fake.random.uniform(0.1, 1.0),
|
||||
"max_tokens": fake.random_int(1000, 4000),
|
||||
},
|
||||
"metadata": {
|
||||
"source": fake.random_element(["upload", "api", "webhook"]),
|
||||
"priority": fake.random_element(["low", "normal", "high"]),
|
||||
},
|
||||
}
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Push tasks
|
||||
test_queue.push_tasks(batch_tasks)
|
||||
|
||||
# Process tasks in batches
|
||||
batch_size = 2
|
||||
processed_tasks = []
|
||||
|
||||
while True:
|
||||
batch = test_queue.pull_tasks(batch_size)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
processed_tasks.extend(batch)
|
||||
|
||||
# Verify all tasks were processed
|
||||
assert len(processed_tasks) == 3
|
||||
|
||||
# Verify task structure
|
||||
for task in processed_tasks:
|
||||
assert isinstance(task, dict)
|
||||
assert "file_id" in task
|
||||
assert "tenant_id" in task
|
||||
assert "processing_config" in task
|
||||
assert "metadata" in task
|
||||
assert task["tenant_id"] == test_queue._tenant_id
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueueCompatibility:
|
||||
"""Compatibility tests for migrating from legacy string-only queues."""
|
||||
|
||||
@pytest.fixture
|
||||
def fake(self):
|
||||
"""Faker instance for generating test data."""
|
||||
return Faker()
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_and_account(self, db_session_with_containers, fake):
|
||||
"""Create test tenant and account for testing."""
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tenant, account
|
||||
|
||||
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
|
||||
"""
|
||||
Test compatibility with legacy queues containing only string data.
|
||||
|
||||
This simulates the scenario where Redis queues already contain string data
|
||||
from the old architecture, and we need to ensure the new code can read them.
|
||||
"""
|
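# Assumed read-path behavior that makes this compatibility possible (inferred from the
# assertions below, not from the queue implementation): each raw item popped from Redis is
# first tried as a serialized TaskWrapper, and if that fails it is returned as a plain string.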
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "legacy_queue")
|
||||
|
||||
# Simulate legacy string data in Redis queue (using old format)
|
||||
legacy_strings = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
|
||||
|
||||
# Manually push legacy strings directly to Redis (simulating old system)
|
||||
for legacy_string in legacy_strings:
|
||||
redis_client.lpush(queue._queue, legacy_string)
|
||||
|
||||
# Verify new code can read legacy string data
|
||||
pulled_tasks = queue.pull_tasks(5)
|
||||
assert len(pulled_tasks) == 5
|
||||
|
||||
# Verify all tasks are strings (not wrapped)
|
||||
for task in pulled_tasks:
|
||||
assert isinstance(task, str)
|
||||
assert task.startswith("legacy_task_")
|
||||
|
||||
# Verify order (FIFO from Redis list)
|
||||
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
|
||||
assert pulled_tasks == expected_order
|
||||
|
||||
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
|
||||
"""
|
||||
Test complete migration scenario from legacy to new system.
|
||||
|
||||
This simulates the real-world scenario where:
|
||||
1. Legacy system has string data in Redis
|
||||
2. New system starts processing the same queue
|
||||
3. Both legacy and new tasks coexist during migration
|
||||
4. New system can handle both formats seamlessly
|
||||
"""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "migration_queue")
|
||||
|
||||
# Phase 1: Legacy system has data
|
||||
legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)]
|
||||
redis_client.lpush(queue._queue, *legacy_tasks)
|
||||
|
||||
# Phase 2: New system starts processing legacy data
|
||||
processed_legacy = []
|
||||
while True:
|
||||
tasks = queue.pull_tasks(1)
|
||||
if not tasks:
|
||||
break
|
||||
processed_legacy.extend(tasks)
|
||||
|
||||
# Verify legacy data was processed correctly
|
||||
assert len(processed_legacy) == 5
|
||||
for task in processed_legacy:
|
||||
assert isinstance(task, str)
|
||||
assert task.startswith("legacy_resource_")
|
||||
|
||||
# Phase 3: New system adds new tasks (mixed types)
|
||||
new_string_tasks = ["new_resource_1", "new_resource_2"]
|
||||
new_object_tasks = [
|
||||
{
|
||||
"resource_id": str(uuid4()),
|
||||
"tenant_id": tenant.id,
|
||||
"processing_type": "new_system",
|
||||
"metadata": {"version": "2.0", "features": ["ai", "ml"]},
|
||||
},
|
||||
{
|
||||
"resource_id": str(uuid4()),
|
||||
"tenant_id": tenant.id,
|
||||
"processing_type": "new_system",
|
||||
"metadata": {"version": "2.0", "features": ["ai", "ml"]},
|
||||
},
|
||||
]
|
||||
|
||||
# Push new tasks using new system
|
||||
queue.push_tasks(new_string_tasks)
|
||||
queue.push_tasks(new_object_tasks)
|
||||
|
||||
# Phase 4: Process all new tasks
|
||||
processed_new = []
|
||||
while True:
|
||||
tasks = queue.pull_tasks(1)
|
||||
if not tasks:
|
||||
break
|
||||
processed_new.extend(tasks)
|
||||
|
||||
# Verify new tasks were processed correctly
|
||||
assert len(processed_new) == 4
|
||||
|
||||
string_tasks = [task for task in processed_new if isinstance(task, str)]
|
||||
object_tasks = [task for task in processed_new if isinstance(task, dict)]
|
||||
|
||||
assert len(string_tasks) == 2
|
||||
assert len(object_tasks) == 2
|
||||
|
||||
# Verify string tasks
|
||||
for task in string_tasks:
|
||||
assert task.startswith("new_resource_")
|
||||
|
||||
# Verify object tasks
|
||||
for task in object_tasks:
|
||||
assert isinstance(task, dict)
|
||||
assert "resource_id" in task
|
||||
assert "tenant_id" in task
|
||||
assert task["tenant_id"] == tenant.id
|
||||
assert task["processing_type"] == "new_system"
|
||||
|
||||
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
|
||||
"""
|
||||
Test error recovery when legacy queue contains malformed data.
|
||||
|
||||
This ensures the new system can gracefully handle corrupted or
|
||||
malformed legacy data without crashing.
|
||||
"""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "error_recovery_queue")
|
||||
|
||||
# Create mix of valid and malformed legacy data
|
||||
mixed_legacy_data = [
|
||||
"valid_legacy_task_1",
|
||||
"valid_legacy_task_2",
|
||||
"malformed_data_string", # This should be treated as string
|
||||
"valid_legacy_task_3",
|
||||
"invalid_json_not_taskwrapper_format", # This should fall back to string (not valid TaskWrapper JSON)
|
||||
"valid_legacy_task_4",
|
||||
]
|
||||
|
||||
# Manually push mixed data directly to Redis
|
||||
redis_client.lpush(queue._queue, *mixed_legacy_data)
|
||||
|
||||
# Process all tasks
|
||||
processed_tasks = []
|
||||
while True:
|
||||
tasks = queue.pull_tasks(1)
|
||||
if not tasks:
|
||||
break
|
||||
processed_tasks.extend(tasks)
|
||||
|
||||
# Verify all tasks were processed (no crashes)
|
||||
assert len(processed_tasks) == 6
|
||||
|
||||
# Verify all tasks are strings (malformed data falls back to string)
|
||||
for task in processed_tasks:
|
||||
assert isinstance(task, str)
|
||||
|
||||
# Verify valid tasks are preserved
|
||||
valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")]
|
||||
assert len(valid_tasks) == 4
|
||||
|
||||
# Verify malformed data is handled gracefully
|
||||
malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")]
|
||||
assert len(malformed_tasks) == 2
|
||||
assert "malformed_data_string" in malformed_tasks
|
||||
assert "invalid_json_not_taskwrapper_format" in malformed_tasks
|
||||
|
|
@ -0,0 +1,311 @@
|
|||
"""
|
||||
Integration tests for Redis broadcast channel implementation using TestContainers.
|
||||
|
||||
This test suite covers real Redis interactions including:
|
||||
- Multiple producer/consumer scenarios
|
||||
- Network failure scenarios
|
||||
- Performance under load
|
||||
- Real-world usage patterns
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
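# --- Illustrative sketch, not a test: the pub/sub surface exercised below. ---
# Assumes an already-connected redis.Redis client; the topic name is a placeholder and only
# the calls made by these tests are assumed to exist.
def _usage_sketch_broadcast(redis_conn: redis.Redis) -> bytes | None:
    channel = RedisBroadcastChannel(redis_conn)
    topic = channel.topic("sketch-topic")
    producer = topic.as_producer()
    with topic.subscribe() as subscription:  # subscriptions support the context-manager protocol
        producer.publish(b"hello")
        return subscription.receive(timeout=1.0)  # None if nothing arrived within the timeout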
||||
|
||||
|
||||
class TestRedisBroadcastChannelIntegration:
|
||||
"""Integration tests for Redis broadcast channel with real Redis instance."""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_container(self) -> Iterator[RedisContainer]:
|
||||
"""Create a Redis container for integration testing."""
|
||||
with RedisContainer(image="redis:6-alpine") as container:
|
||||
yield container
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
|
||||
"""Create a Redis client connected to the test container."""
|
||||
host = redis_container.get_container_host_ip()
|
||||
port = redis_container.get_exposed_port(6379)
|
||||
return redis.Redis(host=host, port=port, decode_responses=False)
|
||||
|
||||
@pytest.fixture
|
||||
def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
|
||||
"""Create a BroadcastChannel instance with real Redis client."""
|
||||
return RedisBroadcastChannel(redis_client)
|
||||
|
||||
@classmethod
|
||||
def _get_test_topic_name(cls):
|
||||
return f"test_topic_{uuid.uuid4()}"
|
||||
|
||||
# ==================== Basic Functionality Tests ====================
|
||||
|
||||
def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel):
|
||||
topic_name = self._get_test_topic_name()
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
subscription = topic.subscribe()
|
||||
consuming_event = threading.Event()
|
||||
|
||||
def consume():
|
||||
msgs = []
|
||||
consuming_event.set()
|
||||
for msg in subscription:
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
producer_future = executor.submit(consume)
|
||||
consuming_event.wait()
|
||||
subscription.close()
|
||||
msgs = producer_future.result(timeout=1)
|
||||
assert msgs == []
|
||||
|
||||
def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test complete end-to-end messaging flow."""
|
||||
topic_name = "test-topic"
|
||||
message = b"hello world"
|
||||
|
||||
# Create producer and subscriber
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
producer = topic.as_producer()
|
||||
subscription = topic.subscribe()
|
||||
|
||||
# Publish and receive message
|
||||
|
||||
def producer_thread():
|
||||
time.sleep(0.1) # Small delay to ensure subscriber is ready
|
||||
producer.publish(message)
|
||||
time.sleep(0.1)
|
||||
subscription.close()
|
||||
|
||||
def consumer_thread() -> list[bytes]:
|
||||
received_messages = []
|
||||
for msg in subscription:
|
||||
received_messages.append(msg)
|
||||
return received_messages
|
||||
|
||||
# Run producer and consumer
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
producer_future = executor.submit(producer_thread)
|
||||
consumer_future = executor.submit(consumer_thread)
|
||||
|
||||
# Wait for completion
|
||||
producer_future.result(timeout=5.0)
|
||||
received_messages = consumer_future.result(timeout=5.0)
|
||||
|
||||
assert len(received_messages) == 1
|
||||
assert received_messages[0] == message
|
||||
|
||||
def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test message broadcasting to multiple subscribers."""
|
||||
topic_name = "broadcast-topic"
|
||||
message = b"broadcast message"
|
||||
subscriber_count = 5
|
||||
|
||||
# Create producer and multiple subscribers
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
producer = topic.as_producer()
|
||||
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
|
||||
|
||||
def producer_thread():
|
||||
time.sleep(0.2) # Allow all subscribers to connect
|
||||
producer.publish(message)
|
||||
time.sleep(0.2)
|
||||
for sub in subscriptions:
|
||||
sub.close()
|
||||
|
||||
def consumer_thread(subscription: Subscription) -> list[bytes]:
|
||||
received_msgs = []
|
||||
while True:
|
||||
try:
|
||||
msg = subscription.receive(0.1)
|
||||
except SubscriptionClosedError:
|
||||
break
|
||||
if msg is None:
|
||||
continue
|
||||
received_msgs.append(msg)
|
||||
if len(received_msgs) >= 1:
|
||||
break
|
||||
return received_msgs
|
||||
|
||||
# Run producer and consumers
|
||||
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
|
||||
producer_future = executor.submit(producer_thread)
|
||||
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
|
||||
|
||||
# Wait for completion
|
||||
producer_future.result(timeout=10.0)
|
||||
msgs_by_consumers = []
|
||||
for future in as_completed(consumer_futures, timeout=10.0):
|
||||
msgs_by_consumers.append(future.result())
|
||||
|
||||
# Close all subscriptions
|
||||
for subscription in subscriptions:
|
||||
subscription.close()
|
||||
|
||||
# Verify all subscribers received the message
|
||||
for msgs in msgs_by_consumers:
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0] == message
|
||||
|
||||
def test_topic_isolation(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test that different topics are isolated from each other."""
|
||||
topic1_name = "topic1"
|
||||
topic2_name = "topic2"
|
||||
message1 = b"message for topic1"
|
||||
message2 = b"message for topic2"
|
||||
|
||||
# Create producers and subscribers for different topics
|
||||
topic1 = broadcast_channel.topic(topic1_name)
|
||||
topic2 = broadcast_channel.topic(topic2_name)
|
||||
|
||||
def producer_thread():
|
||||
time.sleep(0.1)
|
||||
topic1.publish(message1)
|
||||
topic2.publish(message2)
|
||||
|
||||
def consumer_by_thread(topic: Topic) -> list[bytes]:
|
||||
subscription = topic.subscribe()
|
||||
received = []
|
||||
with subscription:
|
||||
for msg in subscription:
|
||||
received.append(msg)
|
||||
if len(received) >= 1:
|
||||
break
|
||||
return received
|
||||
|
||||
# Run all threads
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
producer_future = executor.submit(producer_thread)
|
||||
consumer1_future = executor.submit(consumer_by_thread, topic1)
|
||||
consumer2_future = executor.submit(consumer_by_thread, topic2)
|
||||
|
||||
# Wait for completion
|
||||
producer_future.result(timeout=5.0)
|
||||
received_by_topic1 = consumer1_future.result(timeout=5.0)
|
||||
received_by_topic2 = consumer2_future.result(timeout=5.0)
|
||||
|
||||
# Verify topic isolation
|
||||
assert len(received_by_topic1) == 1
|
||||
assert len(received_by_topic2) == 1
|
||||
assert received_by_topic1[0] == message1
|
||||
assert received_by_topic2[0] == message2
|
||||
|
||||
# ==================== Performance Tests ====================
|
||||
|
||||
def test_concurrent_producers(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test multiple producers publishing to the same topic."""
|
||||
topic_name = "concurrent-producers-topic"
|
||||
producer_count = 5
|
||||
messages_per_producer = 5
|
||||
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
subscription = topic.subscribe()
|
||||
|
||||
expected_total = producer_count * messages_per_producer
|
||||
consumer_ready = threading.Event()
|
||||
|
||||
def producer_thread(producer_idx: int) -> set[bytes]:
|
||||
producer = topic.as_producer()
|
||||
produced = set()
|
||||
for i in range(messages_per_producer):
|
||||
message = f"producer_{producer_idx}_msg_{i}".encode()
|
||||
produced.add(message)
|
||||
producer.publish(message)
|
||||
time.sleep(0.001)  # Small delay to avoid overwhelming the subscriber
|
||||
return produced
|
||||
|
||||
def consumer_thread() -> set[bytes]:
|
||||
received_msgs: set[bytes] = set()
|
||||
with subscription:
|
||||
consumer_ready.set()
|
||||
while True:
|
||||
try:
|
||||
msg = subscription.receive(timeout=0.1)
|
||||
except SubscriptionClosedError:
|
||||
break
|
||||
if msg is None:
|
||||
if len(received_msgs) >= expected_total:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
received_msgs.add(msg)
|
||||
return received_msgs
|
||||
|
||||
# Run producers and consumer
|
||||
with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
|
||||
consumer_future = executor.submit(consumer_thread)
|
||||
consumer_ready.wait()
|
||||
producer_futures = [executor.submit(producer_thread, i) for i in range(producer_count)]
|
||||
|
||||
sent_msgs: set[bytes] = set()
|
||||
# Wait for completion
|
||||
for future in as_completed(producer_futures, timeout=30.0):
|
||||
sent_msgs.update(future.result())
|
||||
|
||||
subscription.close()
|
||||
consumer_received_msgs = consumer_future.result(timeout=30.0)
|
||||
|
||||
# Verify message content
|
||||
assert sent_msgs == consumer_received_msgs
|
||||
|
||||
# ==================== Resource Management Tests ====================
|
||||
|
||||
def test_subscription_cleanup(self, broadcast_channel: BroadcastChannel, redis_client: redis.Redis):
|
||||
"""Test proper cleanup of subscription resources."""
|
||||
topic_name = "cleanup-test-topic"
|
||||
|
||||
# Create multiple subscriptions
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
|
||||
def _consume(sub: Subscription):
|
||||
for _ in sub:
|
||||
pass
|
||||
|
||||
subscriptions = []
|
||||
for i in range(5):
|
||||
subscription = topic.subscribe()
|
||||
subscriptions.append(subscription)
|
||||
|
||||
# Start all subscriptions
|
||||
thread = threading.Thread(target=_consume, args=(subscription,))
|
||||
thread.start()
|
||||
time.sleep(0.01)
|
||||
|
||||
# Verify subscriptions are active
|
||||
pubsub_info = redis_client.pubsub_numsub(topic_name)
|
||||
# pubsub_numsub returns list of tuples, find our topic
|
||||
topic_subscribers = 0
|
||||
for channel, count in pubsub_info:
|
||||
# the channel name returned by redis is bytes.
|
||||
if channel == topic_name.encode():
|
||||
topic_subscribers = count
|
||||
break
|
||||
assert topic_subscribers >= 5
|
||||
|
||||
# Close all subscriptions
|
||||
for subscription in subscriptions:
|
||||
subscription.close()
|
||||
|
||||
# Wait a bit for cleanup
|
||||
time.sleep(1)
|
||||
|
||||
# Verify subscriptions are cleaned up
|
||||
pubsub_info_after = redis_client.pubsub_numsub(topic_name)
|
||||
topic_subscribers_after = 0
|
||||
for channel, count in pubsub_info_after:
|
||||
if channel == topic_name.encode():
|
||||
topic_subscribers_after = count
|
||||
break
|
||||
assert topic_subscribers_after == 0
|
||||
|
|
@ -1,17 +1,33 @@
|
|||
from dataclasses import asdict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_task import (
|
||||
_document_indexing, # Core function
|
||||
_document_indexing_with_tenant_queue, # Tenant queue wrapper function
|
||||
document_indexing_task, # Deprecated old interface
|
||||
normal_document_indexing_task, # New normal task
|
||||
priority_document_indexing_task, # New priority task
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentIndexingTask:
|
||||
"""Integration tests for document_indexing_task using testcontainers."""
|
||||
class TestDocumentIndexingTasks:
|
||||
"""Integration tests for document indexing tasks using testcontainers.
|
||||
|
||||
This test class covers:
|
||||
- Core _document_indexing function
|
||||
- Deprecated document_indexing_task function
|
||||
- New normal_document_indexing_task function
|
||||
- New priority_document_indexing_task function
|
||||
- Tenant queue wrapper _document_indexing_with_tenant_queue function
|
||||
"""
|
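# Rough sketch of the call chain these tests assume (inferred from the imports and assertions
# below, not from the task module itself):
#   document_indexing_task(dataset_id, ids)                  -> _document_indexing(dataset_id, ids)
#   normal_document_indexing_task(tenant, dataset_id, ids)   -> _document_indexing_with_tenant_queue(..., normal_document_indexing_task)
#   priority_document_indexing_task(tenant, dataset_id, ids) -> _document_indexing_with_tenant_queue(..., priority_document_indexing_task)
# The tenant-queue wrapper runs the core indexing first, then pulls a limited number of waiting
# tasks from the tenant's Redis queue and re-dispatches each one via task_func.delay(**task).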
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
|
|
@ -224,7 +240,7 @@ class TestDocumentIndexingTask:
|
|||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify indexing runner was called correctly
|
||||
|
|
@ -232,10 +248,11 @@ class TestDocumentIndexingTask:
|
|||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated to parsing status
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with correct documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
|
|
@ -261,7 +278,7 @@ class TestDocumentIndexingTask:
|
|||
document_ids = [fake.uuid4() for _ in range(3)]
|
||||
|
||||
# Act: Execute the task with non-existent dataset
|
||||
document_indexing_task(non_existent_dataset_id, document_ids)
|
||||
_document_indexing(non_existent_dataset_id, document_ids)
|
||||
|
||||
# Assert: Verify no processing occurred
|
||||
mock_external_service_dependencies["indexing_runner"].assert_not_called()
|
||||
|
|
@ -291,17 +308,18 @@ class TestDocumentIndexingTask:
|
|||
all_document_ids = existing_document_ids + non_existent_document_ids
|
||||
|
||||
# Act: Execute the task with mixed document IDs
|
||||
document_indexing_task(dataset.id, all_document_ids)
|
||||
_document_indexing(dataset.id, all_document_ids)
|
||||
|
||||
# Assert: Verify only existing documents were processed
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only existing documents were updated
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in existing_document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with only existing documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
|
|
@ -333,7 +351,7 @@ class TestDocumentIndexingTask:
|
|||
)
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
|
|
@ -341,10 +359,11 @@ class TestDocumentIndexingTask:
|
|||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing close the session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
def test_document_indexing_task_mixed_document_states(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
@ -407,17 +426,18 @@ class TestDocumentIndexingTask:
|
|||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with mixed document states
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify all documents were updated to parsing status
|
||||
for document in all_documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with all documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
|
|
@ -470,15 +490,16 @@ class TestDocumentIndexingTask:
|
|||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with too many documents for sandbox plan
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify error handling
|
||||
for document in all_documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error is not None
|
||||
assert "batch upload" in document.error
|
||||
assert document.stopped_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.error is not None
|
||||
assert "batch upload" in updated_document.error
|
||||
assert updated_document.stopped_at is not None
|
||||
|
||||
# Verify no indexing runner was called
|
||||
mock_external_service_dependencies["indexing_runner"].assert_not_called()
|
||||
|
|
@ -503,17 +524,18 @@ class TestDocumentIndexingTask:
|
|||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the task with billing disabled
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify successful processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated to parsing status
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
def test_document_indexing_task_document_is_paused_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
@ -541,7 +563,7 @@ class TestDocumentIndexingTask:
|
|||
)
|
||||
|
||||
# Act: Execute the task
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
|
|
@ -549,7 +571,317 @@ class TestDocumentIndexingTask:
|
|||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
assert document.indexing_status == "parsing"
|
||||
assert document.processing_started_at is not None
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
|
||||
def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test document_indexing_task basic functionality.
|
||||
|
||||
This test verifies:
|
||||
- Task function calls the wrapper correctly
|
||||
- Basic parameter passing works
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
|
||||
# Act: Execute the deprecated task (it only takes 2 parameters)
|
||||
document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_normal_document_indexing_task_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test normal_document_indexing_task basic functionality.
|
||||
|
||||
This test verifies:
|
||||
- Task function calls the wrapper correctly
|
||||
- Basic parameter passing works
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
# Act: Execute the new normal task
|
||||
normal_document_indexing_task(tenant_id, dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_priority_document_indexing_task_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test priority_document_indexing_task basic functionality.
|
||||
|
||||
This test verifies:
|
||||
- Task function calls the wrapper correctly
|
||||
- Basic parameter passing works
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
# Act: Execute the new priority task
|
||||
priority_document_indexing_task(tenant_id, dataset.id, document_ids)
|
||||
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_document_indexing_with_tenant_queue_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test _document_indexing_with_tenant_queue function with no waiting tasks.
|
||||
|
||||
This test verifies:
|
||||
- Core indexing logic execution (same as _document_indexing)
|
||||
- Tenant queue cleanup when no waiting tasks
|
||||
- Task function parameter passing
|
||||
- Queue management after processing
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Assert: Verify core processing occurred (same as _document_indexing)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were updated (same as _document_indexing)
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify the run method was called with correct documents
|
||||
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||
assert call_args is not None
|
||||
processed_documents = call_args[0][0]
|
||||
assert len(processed_documents) == 2
|
||||
|
||||
# Verify task function was not called (no waiting tasks)
|
||||
mock_task_func.delay.assert_not_called()
|
||||
|
||||
def test_document_indexing_with_tenant_queue_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Core indexing logic execution
|
||||
- Real Redis-based tenant queue processing of waiting tasks
|
||||
- Task function calls for waiting tasks
|
||||
- Queue management with multiple tasks using actual Redis operations
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
dataset_id = dataset.id
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
|
||||
# Create real queue instance
|
||||
queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_tasks = [
|
||||
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
|
||||
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"]),
|
||||
]
|
||||
# Convert DocumentTask objects to dictionaries for serialization
|
||||
waiting_task_dicts = [asdict(task) for task in waiting_tasks]
|
||||
queue.push_tasks(waiting_task_dicts)
|
||||
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify the task function was re-dispatched for only one waiting task (the default per-tenant concurrency is 1)
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify exactly one waiting task remains in the queue (only one was pulled and re-dispatched)
|
||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||
assert len(remaining_tasks) == 1
|
||||
|
||||
def test_document_indexing_with_tenant_queue_error_handling(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling in _document_indexing_with_tenant_queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Exception handling during core processing
|
||||
- Tenant queue cleanup even on errors using real Redis
|
||||
- Proper error logging
|
||||
- Function completes without raising exceptions
|
||||
- Queue management continues despite core processing errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
tenant_id = dataset.tenant_id
|
||||
dataset_id = dataset.id
|
||||
|
||||
# Mock IndexingRunner to raise an exception
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error")
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
|
||||
# Create real queue instance
|
||||
queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
|
||||
queue.push_tasks([asdict(waiting_task)])
|
||||
|
||||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_document_indexing_with_tenant_queue_tenant_isolation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Different tenants have isolated queues
|
||||
- Tasks from one tenant don't affect another tenant's queue
|
||||
- Queue operations are properly scoped to tenant
|
||||
"""
|
||||
# Arrange: Create test data for two different tenants
|
||||
dataset1, documents1 = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
dataset2, documents2 = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
|
||||
tenant1_id = dataset1.tenant_id
|
||||
tenant2_id = dataset2.tenant_id
|
||||
dataset1_id = dataset1.id
|
||||
dataset2_id = dataset2.id
|
||||
document_ids1 = [doc.id for doc in documents1]
|
||||
document_ids2 = [doc.id for doc in documents2]
|
||||
|
||||
# Mock the task function
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_task_func = MagicMock()
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
|
||||
# Create queue instances for both tenants
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1_id, "document_indexing")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2_id, "document_indexing")
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
|
||||
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
|
||||
|
||||
queue1.push_tasks([asdict(waiting_task1)])
|
||||
queue2.push_tasks([asdict(waiting_task2)])
|
||||
|
||||
# Act: Execute the wrapper function for tenant1 only
|
||||
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_task_func.delay.assert_called_once()
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
assert len(remaining_tasks1) == 0
|
||||
|
||||
# Verify tenant2's queue still has its task (isolation)
|
||||
remaining_tasks2 = queue2.pull_tasks(count=10)
|
||||
assert len(remaining_tasks2) == 1
|
||||
|
||||
# Verify queue keys are different
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
|
|
|||
|
|
@ -0,0 +1,936 @@
|
|||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
|
||||
priority_rag_pipeline_run_task,
|
||||
run_single_rag_pipeline_task,
|
||||
)
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
|
||||
class TestRagPipelineRunTasks:
|
||||
"""Integration tests for RAG pipeline run tasks using testcontainers.
|
||||
|
||||
This test class covers:
|
||||
- priority_rag_pipeline_run_task function
|
||||
- rag_pipeline_run_task function
|
||||
- run_single_rag_pipeline_task function
|
||||
- Real Redis-based TenantIsolatedTaskQueue operations
|
||||
- PipelineGenerator._generate method mocking and parameter validation
|
||||
- File operations and cleanup
|
||||
- Error handling and queue management
|
||||
"""
|
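# Assumed task flow, inferred from the mocked collaborators below rather than from the task
# implementations: each task loads a serialized batch of RagPipelineInvokeEntity payloads via
# FileService.get_file_content, runs every entity through PipelineGenerator._generate, deletes
# the temporary file afterwards, and coordinates follow-up work through the tenant-isolated
# Redis queue.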
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline_generator(self):
|
||||
"""Mock PipelineGenerator._generate method."""
|
||||
with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate:
|
||||
# Mock the _generate method to return a simple response
|
||||
mock_generate.return_value = {"answer": "Test response", "metadata": {"test": "data"}}
|
||||
yield mock_generate
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_service(self):
|
||||
"""Mock FileService for file operations."""
|
||||
with (
|
||||
patch("services.file_service.FileService.get_file_content") as mock_get_content,
|
||||
patch("services.file_service.FileService.delete_file") as mock_delete_file,
|
||||
):
|
||||
yield {
|
||||
"get_content": mock_get_content,
|
||||
"delete_file": mock_delete_file,
|
||||
}
|
||||
|
||||
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
|
||||
"""
|
||||
Helper method to create test pipeline and workflow for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant, pipeline, workflow) - Created entities
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
marked_name=fake.company(),
|
||||
marked_comment=fake.text(max_nb_chars=100),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
workflow_id=workflow.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(pipeline)
|
||||
db.session.commit()
|
||||
|
||||
# Refresh entities to ensure they're properly loaded
|
||||
db.session.refresh(account)
|
||||
db.session.refresh(tenant)
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(pipeline)
|
||||
|
||||
return account, tenant, pipeline, workflow
|
||||
|
||||
def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2):
|
||||
"""
|
||||
Helper method to create RAG pipeline invoke entities for testing.
|
||||
|
||||
Args:
|
||||
account: Account instance
|
||||
tenant: Tenant instance
|
||||
pipeline: Pipeline instance
|
||||
workflow: Workflow instance
|
||||
count: Number of entities to create
|
||||
|
||||
Returns:
|
||||
list: List of RagPipelineInvokeEntity instances
|
||||
"""
|
||||
fake = Faker()
|
||||
entities = []
|
||||
|
||||
for i in range(count):
|
||||
# Create application generate entity
|
||||
app_config = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": fake.company(),
|
||||
"mode": "workflow",
|
||||
"workflow_id": workflow.id,
|
||||
"tenant_id": tenant.id,
|
||||
"app_mode": "workflow",
|
||||
}
|
||||
|
||||
application_generate_entity = {
|
||||
"task_id": str(uuid.uuid4()),
|
||||
"app_config": app_config,
|
||||
"inputs": {"query": f"Test query {i}"},
|
||||
"files": [],
|
||||
"user_id": account.id,
|
||||
"stream": False,
|
||||
"invoke_from": "published",
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"pipeline_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": fake.company(),
|
||||
"mode": "workflow",
|
||||
"workflow_id": workflow.id,
|
||||
"tenant_id": tenant.id,
|
||||
"app_mode": "workflow",
|
||||
},
|
||||
"datasource_type": "upload_file",
|
||||
"datasource_info": {},
|
||||
"dataset_id": str(uuid.uuid4()),
|
||||
"batch": "test_batch",
|
||||
}
|
||||
|
||||
entity = RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline.id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
workflow_id=workflow.id,
|
||||
streaming=False,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
workflow_thread_pool_id=str(uuid.uuid4()),
|
||||
)
|
||||
entities.append(entity)
|
||||
|
||||
return entities
|
||||
|
||||
def _create_file_content_for_entities(self, entities):
|
||||
"""
|
||||
Helper method to create file content for RAG pipeline invoke entities.
|
||||
|
||||
Args:
|
||||
entities: List of RagPipelineInvokeEntity instances
|
||||
|
||||
Returns:
|
||||
str: JSON string containing serialized entities
|
||||
"""
|
||||
entities_data = [entity.model_dump() for entity in entities]
|
||||
return json.dumps(entities_data)
|
||||
|
||||
def test_priority_rag_pipeline_run_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test successful priority RAG pipeline run task execution.
|
||||
|
||||
This test verifies:
|
||||
- Task execution with multiple RAG pipeline invoke entities
|
||||
- File content retrieval and parsing
|
||||
- PipelineGenerator._generate method calls with correct parameters
|
||||
- Thread pool execution
|
||||
- File cleanup after execution
|
||||
- Queue management with no waiting tasks
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Act: Execute the priority task
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify expected outcomes
|
||||
# Verify file operations
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
|
||||
# Verify PipelineGenerator._generate was called for each entity
|
||||
assert mock_pipeline_generator.call_count == 2
|
||||
|
||||
# Verify call parameters for each entity
|
||||
calls = mock_pipeline_generator.call_args_list
|
||||
for call in calls:
|
||||
call_kwargs = call[1] # Get keyword arguments
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_rag_pipeline_run_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test successful regular RAG pipeline run task execution.
|
||||
|
||||
This test verifies:
|
||||
- Task execution with multiple RAG pipeline invoke entities
|
||||
- File content retrieval and parsing
|
||||
- PipelineGenerator._generate method calls with correct parameters
|
||||
- Thread pool execution
|
||||
- File cleanup after execution
|
||||
- Queue management with no waiting tasks
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify expected outcomes
|
||||
# Verify file operations
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
|
||||
# Verify PipelineGenerator._generate was called for each entity
|
||||
assert mock_pipeline_generator.call_count == 3
|
||||
|
||||
# Verify call parameters for each entity
|
||||
calls = mock_pipeline_generator.call_args_list
|
||||
for call in calls:
|
||||
call_kwargs = call[1] # Get keyword arguments
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Core task execution
|
||||
- Real Redis-based tenant queue processing of waiting tasks
|
||||
- Task function calls for waiting tasks
|
||||
- Queue management with multiple tasks using actual Redis operations
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act: Execute the priority task
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
            # Verify waiting tasks were processed; only one task is pulled at a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
|
||||
|
||||
def test_rag_pipeline_run_task_legacy_compatibility(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
|
||||
|
||||
This test simulates the scenario where:
|
||||
- Old code writes file IDs directly to Redis list using lpush
|
||||
- New worker processes these legacy queue entries
|
||||
- Ensures backward compatibility during deployment transition
|
||||
|
||||
Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
|
||||
New format: TenantIsolatedTaskQueue.push_tasks([file_id])
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Simulate legacy Redis queue format - direct file IDs in Redis list
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Legacy queue key format (old code)
|
||||
legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
|
||||
legacy_task_key = f"tenant_pipeline_task:{tenant.id}"
|
||||
|
||||
# Add legacy format data to Redis (simulating old code behavior)
|
||||
legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
for file_id_legacy in legacy_file_ids:
|
||||
redis_client.lpush(legacy_queue_key, file_id_legacy)
|
||||
|
||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
            # Act: Execute the regular task with new code but legacy queue data
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
            # Verify waiting tasks were processed; only one task is pulled at a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify that new code can process legacy queue entries
|
||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
|
||||
|
||||
# Cleanup: Remove legacy test data
|
||||
redis_client.delete(legacy_queue_key)
|
||||
redis_client.delete(legacy_task_key)
|
||||
|
||||
def test_rag_pipeline_run_task_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Core task execution
|
||||
- Real Redis-based tenant queue processing of waiting tasks
|
||||
- Task function calls for waiting tasks
|
||||
- Queue management with multiple tasks using actual Redis operations
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify core processing occurred
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
            # Verify waiting tasks were processed; only one task is pulled at a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
|
||||
|
||||
def test_priority_rag_pipeline_run_task_error_handling(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test error handling in priority RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Exception handling during core processing
|
||||
- Tenant queue cleanup even on errors using real Redis
|
||||
- Proper error logging
|
||||
- Function completes without raising exceptions
|
||||
- Queue management continues despite core processing errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Mock PipelineGenerator to raise an exception
|
||||
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act: Execute the priority task (should not raise exception)
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_rag_pipeline_run_task_error_handling(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test error handling in regular RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Exception handling during core processing
|
||||
- Tenant queue cleanup even on errors using real Redis
|
||||
- Proper error logging
|
||||
- Function completes without raising exceptions
|
||||
- Queue management continues despite core processing errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
file_content = self._create_file_content_for_entities(entities)
|
||||
|
||||
# Mock file service
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].return_value = file_content
|
||||
|
||||
# Mock PipelineGenerator to raise an exception
|
||||
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task (should not raise exception)
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_priority_rag_pipeline_run_task_tenant_isolation(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in priority RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Different tenants have isolated queues
|
||||
- Tasks from one tenant don't affect another tenant's queue
|
||||
- Queue operations are properly scoped to tenant
|
||||
"""
|
||||
# Arrange: Create test data for two different tenants
|
||||
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
|
||||
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
|
||||
|
||||
file_content1 = self._create_file_content_for_entities(entities1)
|
||||
file_content2 = self._create_file_content_for_entities(entities2)
|
||||
|
||||
# Mock file service
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_file_id1 = str(uuid.uuid4())
|
||||
waiting_file_id2 = str(uuid.uuid4())
|
||||
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act: Execute the priority task for tenant1 only
|
||||
priority_rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
assert mock_file_service["get_content"].call_count == 1
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
assert len(remaining_tasks1) == 0
|
||||
|
||||
# Verify tenant2's queue still has its task (isolation)
|
||||
remaining_tasks2 = queue2.pull_tasks(count=10)
|
||||
assert len(remaining_tasks2) == 1
|
||||
|
||||
# Verify queue keys are different
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
def test_rag_pipeline_run_task_tenant_isolation(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in regular RAG pipeline run task using real Redis.
|
||||
|
||||
This test verifies:
|
||||
- Different tenants have isolated queues
|
||||
- Tasks from one tenant don't affect another tenant's queue
|
||||
- Queue operations are properly scoped to tenant
|
||||
"""
|
||||
# Arrange: Create test data for two different tenants
|
||||
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
|
||||
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
|
||||
|
||||
file_content1 = self._create_file_content_for_entities(entities1)
|
||||
file_content2 = self._create_file_content_for_entities(entities2)
|
||||
|
||||
# Mock file service
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
|
||||
queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_file_id1 = str(uuid.uuid4())
|
||||
waiting_file_id2 = str(uuid.uuid4())
|
||||
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task for tenant1 only
|
||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
assert mock_file_service["get_content"].call_count == 1
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
assert len(remaining_tasks1) == 0
|
||||
|
||||
# Verify tenant2's queue still has its task (isolation)
|
||||
remaining_tasks2 = queue2.pull_tasks(count=10)
|
||||
assert len(remaining_tasks2) == 1
|
||||
|
||||
# Verify queue keys are different
|
||||
assert queue1._queue != queue2._queue
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
def test_run_single_rag_pipeline_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test successful run_single_rag_pipeline_task execution.
|
||||
|
||||
This test verifies:
|
||||
- Single RAG pipeline task execution within Flask app context
|
||||
- Entity validation and database queries
|
||||
- PipelineGenerator._generate method call with correct parameters
|
||||
- Proper Flask context handling
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
|
||||
entity_data = entities[0].model_dump()
|
||||
|
||||
# Act: Execute the single task
|
||||
with flask_app_with_containers.app_context():
|
||||
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
|
||||
|
||||
# Assert: Verify expected outcomes
|
||||
# Verify PipelineGenerator._generate was called
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify call parameters
|
||||
call = mock_pipeline_generator.call_args
|
||||
call_kwargs = call[1] # Get keyword arguments
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_run_single_rag_pipeline_task_entity_validation_error(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
        Test run_single_rag_pipeline_task with entity data that references non-existent records.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid entity data
|
||||
- Exception logging
|
||||
- Function raises ValueError for missing entities
|
||||
"""
|
||||
# Arrange: Create entity data with valid UUIDs but non-existent entities
|
||||
fake = Faker()
|
||||
invalid_entity_data = {
|
||||
"pipeline_id": str(uuid.uuid4()),
|
||||
"application_generate_entity": {
|
||||
"app_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": "Test App",
|
||||
"mode": "workflow",
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
},
|
||||
"inputs": {"query": "Test query"},
|
||||
"query": "Test query",
|
||||
"response_mode": "blocking",
|
||||
"user": str(uuid.uuid4()),
|
||||
"files": [],
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
},
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
"streaming": False,
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"workflow_thread_pool_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Account .* not found"):
|
||||
run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers)
|
||||
|
||||
# Assert: Pipeline generator should not be called
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
def test_run_single_rag_pipeline_task_database_entity_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test run_single_rag_pipeline_task with non-existent database entities.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing database entities
|
||||
- Exception logging
|
||||
- Function raises ValueError for missing entities
|
||||
"""
|
||||
# Arrange: Create test data with non-existent IDs
|
||||
fake = Faker()
|
||||
entity_data = {
|
||||
"pipeline_id": str(uuid.uuid4()),
|
||||
"application_generate_entity": {
|
||||
"app_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"app_name": "Test App",
|
||||
"mode": "workflow",
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
},
|
||||
"inputs": {"query": "Test query"},
|
||||
"query": "Test query",
|
||||
"response_mode": "blocking",
|
||||
"user": str(uuid.uuid4()),
|
||||
"files": [],
|
||||
"conversation_id": str(uuid.uuid4()),
|
||||
},
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"workflow_id": str(uuid.uuid4()),
|
||||
"streaming": False,
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"workflow_thread_pool_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Account .* not found"):
|
||||
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
|
||||
|
||||
# Assert: Pipeline generator should not be called
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
def test_priority_rag_pipeline_run_task_file_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test priority RAG pipeline run task with non-existent file.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing files
|
||||
- Exception logging
|
||||
- Function raises Exception for file errors
|
||||
- Queue management continues despite file errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
# Mock file service to raise exception
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = Exception("File not found")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch(
|
||||
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
|
||||
) as mock_delay:
|
||||
# Act & Assert: Execute the priority task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
priority_rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_rag_pipeline_run_task_file_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with non-existent file.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing files
|
||||
- Exception logging
|
||||
- Function raises Exception for file errors
|
||||
- Queue management continues despite file errors
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
|
||||
|
||||
# Mock file service to raise exception
|
||||
file_id = str(uuid.uuid4())
|
||||
mock_file_service["get_content"].side_effect = Exception("File not found")
|
||||
|
||||
# Use real Redis for TenantIsolatedTaskQueue
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act & Assert: Execute the regular task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
mock_file_service["get_content"].assert_called_once_with(file_id)
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
assert len(remaining_tasks) == 0
@@ -0,0 +1,301 @@
"""
|
||||
Unit tests for TenantIsolatedTaskQueue.
|
||||
|
||||
These tests verify the Redis-based task queue functionality for tenant-specific
|
||||
task management with proper serialization and deserialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
|
||||
|
||||
class TestTaskWrapper:
|
||||
"""Test cases for TaskWrapper serialization/deserialization."""
|
||||
|
||||
def test_serialize_simple_data(self):
|
||||
"""Test serialization of simple data types."""
|
||||
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
# Verify it's valid JSON
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["data"] == data
|
||||
|
||||
def test_serialize_complex_data(self):
|
||||
"""Test serialization of complex nested data."""
|
||||
data = {
|
||||
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
|
||||
"unicode": "测试中文",
|
||||
"special_chars": "!@#$%^&*()",
|
||||
}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["data"] == data
|
||||
|
||||
def test_deserialize_valid_data(self):
|
||||
"""Test deserialization of valid JSON data."""
|
||||
original_data = {"key": "value", "number": 42}
|
||||
# Serialize using TaskWrapper to get the correct format
|
||||
wrapper = TaskWrapper(data=original_data)
|
||||
serialized = wrapper.serialize()
|
||||
|
||||
wrapper = TaskWrapper.deserialize(serialized)
|
||||
assert wrapper.data == original_data
|
||||
|
||||
def test_deserialize_invalid_json(self):
|
||||
"""Test deserialization handles invalid JSON gracefully."""
|
||||
invalid_json = "{invalid json}"
|
||||
|
||||
# Pydantic will raise ValidationError for invalid JSON
|
||||
with pytest.raises(ValidationError):
|
||||
TaskWrapper.deserialize(invalid_json)
|
||||
|
||||
def test_serialize_ensure_ascii_false(self):
|
||||
"""Test that serialization preserves Unicode characters."""
|
||||
data = {"chinese": "中文测试", "emoji": "🚀"}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
assert "中文测试" in serialized
|
||||
assert "🚀" in serialized
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueue:
|
||||
"""Test cases for TenantIsolatedTaskQueue functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self):
|
||||
"""Mock Redis client for testing."""
|
||||
mock_redis = MagicMock()
|
||||
return mock_redis
|
||||
|
||||
@pytest.fixture
|
||||
def sample_queue(self, mock_redis_client):
|
||||
"""Create a sample TenantIsolatedTaskQueue instance."""
|
||||
return TenantIsolatedTaskQueue("tenant-123", "test-key")
|
||||
|
||||
def test_initialization(self, sample_queue):
|
||||
"""Test queue initialization with correct key generation."""
|
||||
assert sample_queue._tenant_id == "tenant-123"
|
||||
assert sample_queue._unique_key == "test-key"
|
||||
assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
|
||||
assert sample_queue._task_key == "tenant_test-key_task:tenant-123"
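        # Key naming scheme implied by these assertions:
        #     queue list key: f"tenant_self_{unique_key}_task_queue:{tenant_id}"
        #     waiting flag:   f"tenant_{unique_key}_task:{tenant_id}"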
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_get_task_key_exists(self, mock_redis, sample_queue):
|
||||
"""Test getting task key when it exists."""
|
||||
mock_redis.get.return_value = "1"
|
||||
|
||||
result = sample_queue.get_task_key()
|
||||
|
||||
assert result == "1"
|
||||
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
|
||||
"""Test getting task key when it doesn't exist."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = sample_queue.get_task_key()
|
||||
|
||||
assert result is None
|
||||
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
|
||||
"""Test setting task waiting flag with default TTL."""
|
||||
sample_queue.set_task_waiting_time()
|
||||
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
"tenant_test-key_task:tenant-123",
|
||||
3600, # DEFAULT_TASK_TTL
|
||||
1,
|
||||
)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
|
||||
"""Test setting task waiting flag with custom TTL."""
|
||||
custom_ttl = 1800
|
||||
sample_queue.set_task_waiting_time(custom_ttl)
|
||||
|
||||
mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_delete_task_key(self, mock_redis, sample_queue):
|
||||
"""Test deleting task key."""
|
||||
sample_queue.delete_task_key()
|
||||
|
||||
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_string_list(self, mock_redis, sample_queue):
|
||||
"""Test pushing string tasks directly."""
|
||||
tasks = ["task1", "task2", "task3"]
|
||||
|
||||
sample_queue.push_tasks(tasks)
|
||||
|
||||
mock_redis.lpush.assert_called_once_with(
|
||||
"tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
|
||||
)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
|
||||
"""Test pushing mixed string and object tasks."""
|
||||
tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
|
||||
|
||||
sample_queue.push_tasks(tasks)
|
||||
|
||||
# Verify lpush was called
|
||||
mock_redis.lpush.assert_called_once()
|
||||
call_args = mock_redis.lpush.call_args
|
||||
|
||||
# Check queue name
|
||||
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
|
||||
|
||||
# Check serialized tasks
|
||||
serialized_tasks = call_args[0][1:]
|
||||
assert len(serialized_tasks) == 3
|
||||
assert serialized_tasks[0] == "string_task"
|
||||
assert serialized_tasks[2] == "another_string"
|
||||
|
||||
# Check object task is serialized as TaskWrapper JSON (without prefix)
|
||||
# It should be a valid JSON string that can be deserialized by TaskWrapper
|
||||
wrapper = TaskWrapper.deserialize(serialized_tasks[1])
|
||||
assert wrapper.data == {"object_task": "data", "id": 123}
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
|
||||
"""Test pushing empty task list."""
|
||||
sample_queue.push_tasks([])
|
||||
|
||||
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with default count (1)."""
|
||||
mock_redis.rpop.side_effect = ["task1", None]
|
||||
|
||||
result = sample_queue.pull_tasks()
|
||||
|
||||
assert result == ["task1"]
|
||||
assert mock_redis.rpop.call_count == 1
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with custom count."""
|
||||
# First test: pull 3 tasks
|
||||
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
|
||||
|
||||
result = sample_queue.pull_tasks(3)
|
||||
|
||||
assert result == ["task1", "task2", "task3"]
|
||||
assert mock_redis.rpop.call_count == 3
|
||||
|
||||
# Reset mock for second test
|
||||
mock_redis.reset_mock()
|
||||
mock_redis.rpop.side_effect = ["task1", "task2", None]
|
||||
|
||||
result = sample_queue.pull_tasks(3)
|
||||
|
||||
assert result == ["task1", "task2"]
|
||||
assert mock_redis.rpop.call_count == 3
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with zero count returns empty list."""
|
||||
result = sample_queue.pull_tasks(0)
|
||||
|
||||
assert result == []
|
||||
mock_redis.rpop.assert_not_called()
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with negative count returns empty list."""
|
||||
result = sample_queue.pull_tasks(-1)
|
||||
|
||||
assert result == []
|
||||
mock_redis.rpop.assert_not_called()
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks that include wrapped objects."""
|
||||
# Create a wrapped task
|
||||
task_data = {"task_id": 123, "data": "test"}
|
||||
wrapper = TaskWrapper(data=task_data)
|
||||
wrapped_task = wrapper.serialize()
|
||||
|
||||
mock_redis.rpop.side_effect = [
|
||||
"string_task",
|
||||
wrapped_task.encode("utf-8"), # Simulate bytes from Redis
|
||||
None,
|
||||
]
|
||||
|
||||
result = sample_queue.pull_tasks(2)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == "string_task"
|
||||
assert result[1] == {"task_id": 123, "data": "test"}
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with invalid JSON falls back to string."""
|
||||
# Invalid JSON string that cannot be deserialized
|
||||
invalid_json = "invalid json data"
|
||||
mock_redis.rpop.side_effect = [invalid_json, None]
|
||||
|
||||
result = sample_queue.pull_tasks(1)
|
||||
|
||||
assert result == [invalid_json]
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks handles bytes from Redis correctly."""
|
||||
mock_redis.rpop.side_effect = [
|
||||
b"task1", # bytes
|
||||
"task2", # string
|
||||
None,
|
||||
]
|
||||
|
||||
result = sample_queue.pull_tasks(2)
|
||||
|
||||
assert result == ["task1", "task2"]
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
|
||||
"""Test complex object serialization and deserialization roundtrip."""
|
||||
complex_task = {
|
||||
"id": uuid4().hex,
|
||||
"data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
|
||||
"metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
|
||||
}
|
||||
|
||||
# Push the complex task
|
||||
sample_queue.push_tasks([complex_task])
|
||||
|
||||
# Verify it was serialized as TaskWrapper JSON
|
||||
call_args = mock_redis.lpush.call_args
|
||||
wrapped_task = call_args[0][1]
|
||||
# Verify it's a valid TaskWrapper JSON (starts with {"data":)
|
||||
assert wrapped_task.startswith('{"data":')
|
||||
|
||||
# Verify it can be deserialized
|
||||
wrapper = TaskWrapper.deserialize(wrapped_task)
|
||||
assert wrapper.data == complex_task
|
||||
|
||||
# Simulate pulling it back
|
||||
mock_redis.rpop.return_value = wrapped_task
|
||||
result = sample_queue.pull_tasks(1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == complex_task
@@ -111,3 +111,26 @@ class TestVariablePoolGetAndNestedAttribute:
assert segment_false is not None
|
||||
assert isinstance(segment_false, BooleanSegment)
|
||||
assert segment_false.value is False
|
||||
|
||||
|
||||
class TestVariablePoolGetNotModifyVariableDictionary:
|
||||
_NODE_ID = "start"
|
||||
_VAR_NAME = "name"
|
||||
|
||||
def test_convert_to_template_should_not_introduce_extra_keys(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add([self._NODE_ID, self._VAR_NAME], 0)
|
||||
pool.convert_template("The start.name is {{#start.name#}}")
|
||||
assert "The start" not in pool.variable_dictionary
|
||||
|
||||
def test_get_should_not_modify_variable_dictionary(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.get([self._NODE_ID, self._VAR_NAME])
|
||||
assert len(pool.variable_dictionary) == 1 # only contains `sys` node id
|
||||
assert "start" not in pool.variable_dictionary
|
||||
|
||||
pool = VariablePool.empty()
|
||||
pool.add([self._NODE_ID, self._VAR_NAME], "Joe")
|
||||
pool.get([self._NODE_ID, "count"])
|
||||
start_subdict = pool.variable_dictionary[self._NODE_ID]
|
||||
assert "count" not in start_subdict
@@ -0,0 +1,514 @@
"""
|
||||
Comprehensive unit tests for Redis broadcast channel implementation.
|
||||
|
||||
This test suite covers all aspects of the Redis broadcast channel including:
|
||||
- Basic functionality and contract compliance
|
||||
- Error handling and edge cases
|
||||
- Thread safety and concurrency
|
||||
- Resource management and cleanup
|
||||
- Performance and reliability scenarios
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import (
|
||||
BroadcastChannel as RedisBroadcastChannel,
|
||||
)
|
||||
from libs.broadcast_channel.redis.channel import (
|
||||
Topic,
|
||||
_RedisSubscription,
|
||||
)
|
||||
|
||||
|
||||
class TestBroadcastChannel:
|
||||
"""Test cases for the main BroadcastChannel class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
"""Create a mock Redis client for testing."""
|
||||
client = MagicMock()
|
||||
client.pubsub.return_value = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel:
|
||||
"""Create a BroadcastChannel instance with mock Redis client."""
|
||||
return RedisBroadcastChannel(mock_redis_client)
|
||||
|
||||
def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock):
|
||||
"""Test that topic() method returns a Topic instance with correct parameters."""
|
||||
topic_name = "test-topic"
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
|
||||
assert isinstance(topic, Topic)
|
||||
assert topic._client == mock_redis_client
|
||||
assert topic._topic == topic_name
|
||||
|
||||
def test_topic_isolation(self, broadcast_channel: RedisBroadcastChannel):
|
||||
"""Test that different topic names create isolated Topic instances."""
|
||||
topic1 = broadcast_channel.topic("topic1")
|
||||
topic2 = broadcast_channel.topic("topic2")
|
||||
|
||||
assert topic1 is not topic2
|
||||
assert topic1._topic == "topic1"
|
||||
assert topic2._topic == "topic2"
|
||||
|
||||
|
||||
class TestTopic:
|
||||
"""Test cases for the Topic class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
"""Create a mock Redis client for testing."""
|
||||
client = MagicMock()
|
||||
client.pubsub.return_value = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def topic(self, mock_redis_client: MagicMock) -> Topic:
|
||||
"""Create a Topic instance for testing."""
|
||||
return Topic(mock_redis_client, "test-topic")
|
||||
|
||||
def test_as_producer_returns_self(self, topic: Topic):
|
||||
"""Test that as_producer() returns self as Producer interface."""
|
||||
producer = topic.as_producer()
|
||||
assert producer is topic
|
||||
# Producer is a Protocol, check duck typing instead
|
||||
assert hasattr(producer, "publish")
|
||||
|
||||
def test_as_subscriber_returns_self(self, topic: Topic):
|
||||
"""Test that as_subscriber() returns self as Subscriber interface."""
|
||||
subscriber = topic.as_subscriber()
|
||||
assert subscriber is topic
|
||||
# Subscriber is a Protocol, check duck typing instead
|
||||
assert hasattr(subscriber, "subscribe")
|
||||
|
||||
def test_publish_calls_redis_publish(self, topic: Topic, mock_redis_client: MagicMock):
|
||||
"""Test that publish() calls Redis PUBLISH with correct parameters."""
|
||||
payload = b"test message"
|
||||
topic.publish(payload)
|
||||
|
||||
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SubscriptionTestCase:
|
||||
"""Test case data for subscription tests."""
|
||||
|
||||
name: str
|
||||
buffer_size: int
|
||||
payload: bytes
|
||||
expected_messages: list[bytes]
|
||||
should_drop: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
class TestRedisSubscription:
|
||||
"""Test cases for the _RedisSubscription class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
pubsub = MagicMock()
|
||||
pubsub.subscribe = MagicMock()
|
||||
pubsub.unsubscribe = MagicMock()
|
||||
pubsub.close = MagicMock()
|
||||
pubsub.get_message = MagicMock()
|
||||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
|
||||
"""Create a _RedisSubscription instance for testing."""
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
yield subscription
|
||||
subscription.close()
|
||||
|
||||
@pytest.fixture
|
||||
def started_subscription(self, subscription: _RedisSubscription) -> _RedisSubscription:
|
||||
"""Create a subscription that has been started."""
|
||||
subscription._start_if_needed()
|
||||
return subscription
|
||||
|
||||
# ==================== Lifecycle Tests ====================
|
||||
|
||||
def test_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||
"""Test that subscription is properly initialized."""
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
||||
assert subscription._pubsub is mock_pubsub
|
||||
assert subscription._topic == "test-topic"
|
||||
assert not subscription._closed.is_set()
|
||||
assert subscription._dropped_count == 0
|
||||
assert subscription._listener_thread is None
|
||||
assert not subscription._started
|
||||
|
||||
def test_start_if_needed_first_call(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that _start_if_needed() properly starts subscription on first call."""
|
||||
subscription._start_if_needed()
|
||||
|
||||
mock_pubsub.subscribe.assert_called_once_with("test-topic")
|
||||
assert subscription._started is True
|
||||
assert subscription._listener_thread is not None
|
||||
|
||||
def test_start_if_needed_subsequent_calls(self, started_subscription: _RedisSubscription):
|
||||
"""Test that _start_if_needed() doesn't start subscription on subsequent calls."""
|
||||
original_thread = started_subscription._listener_thread
|
||||
started_subscription._start_if_needed()
|
||||
|
||||
# Should not create new thread or generator
|
||||
assert started_subscription._listener_thread is original_thread
|
||||
|
||||
def test_start_if_needed_when_closed(self, subscription: _RedisSubscription):
|
||||
"""Test that _start_if_needed() raises error when subscription is closed."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription):
|
||||
"""Test that _start_if_needed() raises error when pubsub is None."""
|
||||
subscription._pubsub = None
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that subscription works as context manager."""
|
||||
with subscription as sub:
|
||||
assert sub is subscription
|
||||
assert subscription._started is True
|
||||
mock_pubsub.subscribe.assert_called_once_with("test-topic")
|
||||
|
||||
def test_close_idempotent(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that close() is idempotent and can be called multiple times."""
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Close multiple times
|
||||
subscription.close()
|
||||
subscription.close()
|
||||
subscription.close()
|
||||
|
||||
# Should only cleanup once
|
||||
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
assert subscription._pubsub is None
|
||||
assert subscription._closed.is_set()
|
||||
|
||||
def test_close_cleanup(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that close() properly cleans up all resources."""
|
||||
subscription._start_if_needed()
|
||||
thread = subscription._listener_thread
|
||||
|
||||
subscription.close()
|
||||
|
||||
# Verify cleanup
|
||||
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
assert subscription._pubsub is None
|
||||
assert subscription._listener_thread is None
|
||||
|
||||
# Wait for thread to finish (with timeout)
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=1.0)
|
||||
assert not thread.is_alive()
|
||||
|
||||
# ==================== Message Processing Tests ====================
|
||||
|
||||
def test_message_iterator_with_messages(self, started_subscription: _RedisSubscription):
|
||||
"""Test message iterator behavior with messages in queue."""
|
||||
test_messages = [b"msg1", b"msg2", b"msg3"]
|
||||
|
||||
# Add messages to queue
|
||||
for msg in test_messages:
|
||||
started_subscription._queue.put_nowait(msg)
|
||||
|
||||
# Iterate through messages
|
||||
iterator = iter(started_subscription)
|
||||
received_messages = []
|
||||
|
||||
for msg in iterator:
|
||||
received_messages.append(msg)
|
||||
if len(received_messages) >= len(test_messages):
|
||||
break
|
||||
|
||||
assert received_messages == test_messages
|
||||
|
||||
def test_message_iterator_when_closed(self, subscription: _RedisSubscription):
|
||||
"""Test that iterator raises error when subscription is closed."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"):
|
||||
iter(subscription)
|
||||
|
||||
# ==================== Message Enqueue Tests ====================
|
||||
|
||||
def test_enqueue_message_success(self, started_subscription: _RedisSubscription):
|
||||
"""Test successful message enqueue."""
|
||||
payload = b"test message"
|
||||
|
||||
started_subscription._enqueue_message(payload)
|
||||
|
||||
assert started_subscription._queue.qsize() == 1
|
||||
assert started_subscription._queue.get_nowait() == payload
|
||||
|
||||
def test_enqueue_message_when_closed(self, subscription: _RedisSubscription):
|
||||
"""Test message enqueue when subscription is closed."""
|
||||
subscription.close()
|
||||
payload = b"test message"
|
||||
|
||||
# Should not raise, and should not enqueue anything
|
||||
subscription._enqueue_message(payload)
|
||||
|
||||
assert subscription._queue.empty()
|
||||
|
||||
def test_enqueue_message_with_full_queue(self, started_subscription: _RedisSubscription):
|
||||
"""Test message enqueue with full queue (dropping behavior)."""
|
||||
# Fill the queue
|
||||
for i in range(started_subscription._queue.maxsize):
|
||||
started_subscription._queue.put_nowait(f"old_msg_{i}".encode())
|
||||
|
||||
# Try to enqueue new message (should drop oldest)
|
||||
new_message = b"new_message"
|
||||
started_subscription._enqueue_message(new_message)
|
||||
|
||||
# Should have dropped one message and added new one
|
||||
assert started_subscription._dropped_count == 1
|
||||
|
||||
# New message should be in queue
|
||||
messages = []
|
||||
while not started_subscription._queue.empty():
|
||||
messages.append(started_subscription._queue.get_nowait())
|
||||
|
||||
assert new_message in messages
|
||||
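The full-queue test above pins down a drop-oldest policy: when the subscriber buffer is full, the oldest pending payload is discarded and a drop counter is incremented so the newest message still fits. A minimal sketch of that policy using only the standard library; the class and attribute names are illustrative, not the real `_RedisSubscription` internals:

```python
import queue


class DropOldestBuffer:
    """Bounded buffer that discards the oldest item when full (illustrative sketch)."""

    def __init__(self, maxsize: int = 5) -> None:
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=maxsize)
        self.dropped_count = 0

    def enqueue(self, payload: bytes) -> None:
        try:
            self._queue.put_nowait(payload)
        except queue.Full:
            try:
                # Make room by discarding the oldest pending message.
                self._queue.get_nowait()
                self.dropped_count += 1
            except queue.Empty:
                pass  # a consumer drained the queue concurrently
            self._queue.put_nowait(payload)


buf = DropOldestBuffer(maxsize=2)
for msg in (b"m1", b"m2", b"m3"):
    buf.enqueue(msg)
assert buf.dropped_count == 1
assert list(buf._queue.queue) == [b"m2", b"m3"]
```

Dropping the oldest item keeps a slow consumer bounded in memory while biasing delivery towards fresh data, which is the trade-off the `_dropped_count` assertion above is checking.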
|
||||
# ==================== Listener Thread Tests ====================
|
||||
|
||||
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
|
||||
def test_listener_thread_normal_operation(
|
||||
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
|
||||
):
|
||||
"""Test listener thread normal operation."""
|
||||
# Mock message from Redis
|
||||
mock_message = {"type": "message", "channel": "test-topic", "data": b"test payload"}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
# Start listener
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Wait a bit for processing
|
||||
time.sleep(0.1)
|
||||
|
||||
# Verify message was processed
|
||||
assert not subscription._queue.empty()
|
||||
assert subscription._queue.get_nowait() == b"test payload"
|
||||
|
||||
def test_listener_thread_ignores_subscribe_messages(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread ignores subscribe/unsubscribe messages."""
|
||||
mock_message = {"type": "subscribe", "channel": "test-topic", "data": 1}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
subscription._start_if_needed()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Should not enqueue subscribe messages
|
||||
assert subscription._queue.empty()
|
||||
|
||||
def test_listener_thread_ignores_wrong_channel(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread ignores messages from wrong channels."""
|
||||
mock_message = {"type": "message", "channel": "wrong-topic", "data": b"test payload"}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
subscription._start_if_needed()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Should not enqueue messages from wrong channels
|
||||
assert subscription._queue.empty()
|
||||
|
||||
def test_listener_thread_handles_redis_exceptions(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread handles Redis exceptions gracefully."""
|
||||
mock_pubsub.get_message.side_effect = Exception("Redis error")
|
||||
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Wait for thread to handle exception
|
||||
time.sleep(0.2)
|
||||
|
||||
# The listener thread should have exited after the unhandled exception
|
||||
assert subscription._listener_thread is not None
|
||||
assert not subscription._listener_thread.is_alive()
|
||||
|
||||
def test_listener_thread_stops_when_closed(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread stops when subscription is closed."""
|
||||
subscription._start_if_needed()
|
||||
thread = subscription._listener_thread
|
||||
|
||||
# Close subscription
|
||||
subscription.close()
|
||||
|
||||
# Wait for thread to finish
|
||||
if thread is not None and thread.is_alive():
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
assert thread is None or not thread.is_alive()
|
||||
|
||||
# ==================== Table-driven Tests ====================
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
SubscriptionTestCase(
|
||||
name="basic_message",
|
||||
buffer_size=5,
|
||||
payload=b"hello world",
|
||||
expected_messages=[b"hello world"],
|
||||
description="Basic message publishing and receiving",
|
||||
),
|
||||
SubscriptionTestCase(
|
||||
name="empty_message",
|
||||
buffer_size=5,
|
||||
payload=b"",
|
||||
expected_messages=[b""],
|
||||
description="Empty message handling",
|
||||
),
|
||||
SubscriptionTestCase(
|
||||
name="large_message",
|
||||
buffer_size=5,
|
||||
payload=b"x" * 10000,
|
||||
expected_messages=[b"x" * 10000],
|
||||
description="Large message handling",
|
||||
),
|
||||
SubscriptionTestCase(
|
||||
name="unicode_message",
|
||||
buffer_size=5,
|
||||
payload="你好世界".encode(),
|
||||
expected_messages=["你好世界".encode()],
|
||||
description="Unicode message handling",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||
"""Test various subscription scenarios using table-driven approach."""
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
||||
# Simulate receiving message
|
||||
mock_message = {"type": "message", "channel": "test-topic", "data": test_case.payload}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
try:
|
||||
with subscription:
|
||||
# Wait for message processing
|
||||
time.sleep(0.1)
|
||||
|
||||
# Collect received messages
|
||||
received = []
|
||||
for msg in subscription:
|
||||
received.append(msg)
|
||||
if len(received) >= len(test_case.expected_messages):
|
||||
break
|
||||
|
||||
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
def test_concurrent_close_and_enqueue(self, started_subscription: _RedisSubscription):
|
||||
"""Test concurrent close and enqueue operations."""
|
||||
errors = []
|
||||
|
||||
def close_subscription():
|
||||
try:
|
||||
time.sleep(0.05) # Small delay
|
||||
started_subscription.close()
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def enqueue_messages():
|
||||
try:
|
||||
for i in range(50):
|
||||
started_subscription._enqueue_message(f"msg_{i}".encode())
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
# Start threads
|
||||
close_thread = threading.Thread(target=close_subscription)
|
||||
enqueue_thread = threading.Thread(target=enqueue_messages)
|
||||
|
||||
close_thread.start()
|
||||
enqueue_thread.start()
|
||||
|
||||
# Wait for completion
|
||||
close_thread.join(timeout=2.0)
|
||||
enqueue_thread.join(timeout=2.0)
|
||||
|
||||
# Should not have any errors (operations should be safe)
|
||||
assert len(errors) == 0
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_iterator_after_close(self, subscription: _RedisSubscription):
|
||||
"""Test iterator behavior after close."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||
iter(subscription)
|
||||
|
||||
def test_start_after_close(self, subscription: _RedisSubscription):
|
||||
"""Test start attempts after close."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
def test_pubsub_none_operations(self, subscription: _RedisSubscription):
|
||||
"""Test operations when pubsub is None."""
|
||||
subscription._pubsub = None
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Close should still work
|
||||
subscription.close() # Should not raise
|
||||
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||
"""Test various channel name formats."""
|
||||
channel_names = [
|
||||
"simple",
|
||||
"with-dashes",
|
||||
"with_underscores",
|
||||
"with.numbers",
|
||||
"WITH.UPPERCASE",
|
||||
"mixed-CASE_name",
|
||||
"very.long.channel.name.with.multiple.parts",
|
||||
]
|
||||
|
||||
for channel_name in channel_names:
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic=channel_name,
|
||||
)
|
||||
|
||||
subscription._start_if_needed()
|
||||
mock_pubsub.subscribe.assert_called_with(channel_name)
|
||||
subscription.close()
|
||||
|
||||
def test_received_on_closed_subscription(self, subscription: _RedisSubscription):
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError):
|
||||
subscription.receive()
|
||||
|
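For orientation, the round trip these mocks simulate can be exercised against a real Redis with plain redis-py. This is only a sketch of the underlying pub/sub mechanics (connection settings and the channel name are placeholders), not the project's `Topic`/`_RedisSubscription` API:

```python
import redis

# Placeholder connection settings; adjust for your environment.
client = redis.Redis(host="localhost", port=6379, db=0)

pubsub = client.pubsub()
pubsub.subscribe("test-topic")

# Wait for the subscribe confirmation before publishing anything.
confirmation = pubsub.get_message(timeout=1.0)
assert confirmation is not None and confirmation["type"] == "subscribe"

client.publish("test-topic", b"test payload")

message = None
for _ in range(5):  # allow a few polls in case delivery lags slightly
    message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
    if message is not None:
        break

assert message is not None
assert message["channel"] == b"test-topic"
assert message["data"] == b"test payload"

pubsub.unsubscribe("test-topic")
pubsub.close()
```

By default redis-py returns channel names and payloads as bytes, which matches the byte-oriented assertions in the tests above.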
|
@ -0,0 +1,317 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
|
||||
|
||||
class DocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_document_task_proxy(
|
||||
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
|
||||
) -> DocumentIndexingTaskProxy:
|
||||
"""Create DocumentIndexingTaskProxy instance for testing."""
|
||||
if document_ids is None:
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
|
||||
class TestDocumentIndexingTaskProxy:
|
||||
"""Test cases for DocumentIndexingTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test DocumentIndexingTaskProxy initialization."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
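The identity assertion in this test relies on `functools.cached_property` semantics: the first access runs the getter and stores the result on the instance, so later accesses return the very same object. A tiny self-contained illustration (the `FeaturesHolder` class is hypothetical, not part of the codebase):

```python
from functools import cached_property


class FeaturesHolder:
    """Illustrative only: shows why the second access returns the same object."""

    def __init__(self, tenant_id: str) -> None:
        self._tenant_id = tenant_id
        self.lookups = 0

    @cached_property
    def features(self) -> dict:
        # Computed once per instance, then stored in the instance's __dict__.
        self.lookups += 1
        return {"tenant_id": self._tenant_id}


holder = FeaturesHolder("tenant-123")
assert holder.features is holder.features  # same cached object
assert holder.lookups == 1                 # the getter ran exactly once
```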
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
|
||||
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
|
||||
assert len(pushed_tasks) == 1
|
||||
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
|
||||
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
|
||||
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
|
||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert: if billing is enabled with a non-sandbox plan, should send to the priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert: if billing is disabled (e.g., self-hosted or enterprise), should send to the priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
# If billing enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
def test_document_task_dataclass(self):
|
||||
"""Test DocumentTask dataclass."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2"]
|
||||
|
||||
# Act
|
||||
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
||||
|
||||
# Assert
|
||||
assert task.tenant_id == tenant_id
|
||||
assert task.dataset_id == dataset_id
|
||||
assert task.document_ids == document_ids
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
def test_initialization_with_empty_document_ids(self):
|
||||
"""Test initialization with empty document_ids list."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = []
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
||||
def test_initialization_with_single_document_id(self):
|
||||
"""Test initialization with single document_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
|
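Read together, the dispatch tests above encode one routing rule: billing enabled on the sandbox plan uses the default tenant queue, billing enabled on any other (or unknown) plan uses the priority tenant queue, and billing disabled (for example self-hosted or enterprise) skips tenant isolation and goes straight to the priority direct queue. A condensed sketch of that decision, with hypothetical callables standing in for the proxy's private `_send_to_*` methods:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable


class CloudPlan(str, Enum):
    # Minimal stand-in for enums.cloud_plan.CloudPlan; only the values the tests touch.
    SANDBOX = "sandbox"
    TEAM = "team"


@dataclass
class Billing:
    enabled: bool
    plan: str | None


def dispatch(
    billing: Billing,
    send_to_default_tenant_queue: Callable[[], None],
    send_to_priority_tenant_queue: Callable[[], None],
    send_to_priority_direct_queue: Callable[[], None],
) -> None:
    """Route an indexing request the way the tests above expect."""
    if not billing.enabled:
        # Self-hosted / enterprise: no per-plan throttling, take the direct priority path.
        send_to_priority_direct_queue()
    elif billing.plan == CloudPlan.SANDBOX:
        # Free tier shares the default tenant-isolated queue.
        send_to_default_tenant_queue()
    else:
        # Paid plans, and the empty/None edge cases, go to the priority tenant queue.
        send_to_priority_tenant_queue()
```

Note how the empty-string and None plans from the edge-case tests fall through to the priority tenant queue simply because they are not equal to the sandbox value.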
@ -0,0 +1,483 @@
|
|||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
|
||||
|
||||
class RagPipelineTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_invoke_entity(
|
||||
pipeline_id: str = "pipeline-123",
|
||||
user_id: str = "user-456",
|
||||
tenant_id: str = "tenant-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
streaming: bool = True,
|
||||
workflow_execution_id: str | None = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> RagPipelineInvokeEntity:
|
||||
"""Create RagPipelineInvokeEntity instance for testing."""
|
||||
return RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline_id,
|
||||
application_generate_entity={"key": "value"},
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_id=workflow_id,
|
||||
streaming=streaming,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_task_proxy(
|
||||
dataset_tenant_id: str = "tenant-123",
|
||||
user_id: str = "user-456",
|
||||
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
|
||||
) -> RagPipelineTaskProxy:
|
||||
"""Create RagPipelineTaskProxy instance for testing."""
|
||||
if rag_pipeline_invoke_entities is None:
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
@staticmethod
|
||||
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
|
||||
"""Create mock upload file."""
|
||||
upload_file = Mock()
|
||||
upload_file.id = file_id
|
||||
return upload_file
|
||||
|
||||
|
||||
class TestRagPipelineTaskProxy:
|
||||
"""Test cases for RagPipelineTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RagPipelineTaskProxy initialization."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
|
||||
|
||||
def test_initialization_with_empty_entities(self):
|
||||
"""Test initialization with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = []
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == []
|
||||
|
||||
def test_initialization_with_multiple_entities(self):
|
||||
"""Test initialization with multiple rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
|
||||
]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert len(proxy._rag_pipeline_invoke_entities) == 3
|
||||
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
|
||||
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
|
||||
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-123"
|
||||
mock_file_service_class.assert_called_once_with(mock_db.engine)
|
||||
|
||||
# Verify upload_text was called with correct parameters
|
||||
mock_file_service.upload_text.assert_called_once()
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text, name, user_id, tenant_id = call_args[0]
|
||||
|
||||
assert name == "rag_pipeline_invoke_entities.json"
|
||||
assert user_id == "user-456"
|
||||
assert tenant_id == "tenant-123"
|
||||
|
||||
# Verify JSON content
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 1
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method with multiple entities."""
|
||||
# Arrange
|
||||
entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
]
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-456"
|
||||
|
||||
# Verify JSON content contains both entities
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text = call_args[0][0]
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 2
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
|
||||
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(upload_file_id, mock_task)
|
||||
|
||||
# If sent to direct queue, tenant_isolated_task_queue should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
# Celery should be called directly
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If task key exists, should push tasks to the queue
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
|
||||
# Celery should not be called directly
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If no task key, should set task waiting time key first
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
# The first task should be sent to celery directly, so push tasks should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(
|
||||
self, mock_db, mock_file_service_class, mock_feature_service
|
||||
):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert: if billing is disabled (e.g., self-hosted or enterprise), should send to the priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
|
||||
"""Test _dispatch method when upload_file_id is empty."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = Mock()
|
||||
mock_upload_file.id = "" # Empty file ID
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="upload_file_id is empty"):
|
||||
proxy._dispatch()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._dispatch = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
proxy._dispatch.assert_called_once()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
|
||||
def test_delay_method_with_empty_entities(self, mock_logger):
|
||||
"""Test delay method with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
|
||||
)
|
||||
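The `_send_to_tenant_queue` tests describe a small handoff protocol: the first request for a tenant (no task key yet) records a waiting-time marker and is handed to Celery immediately, while later requests (task key present) are parked in the tenant-isolated queue for a worker to drain later. A compact sketch of that rule; the `TenantQueue` protocol and the `delay` callable are stand-ins inferred from the tests, not the real `TenantIsolatedTaskQueue` interface:

```python
from typing import Any, Callable, Protocol


class TenantQueue(Protocol):
    # Minimal protocol inferred from the mocks above; not the real class definition.
    def get_task_key(self) -> str | None: ...
    def set_task_waiting_time(self) -> None: ...
    def push_tasks(self, tasks: list[Any]) -> None: ...


def send_to_tenant_queue(queue: TenantQueue, task: Any, delay: Callable[[Any], None]) -> None:
    """First task goes straight to Celery; subsequent tasks wait in the tenant queue."""
    if queue.get_task_key():
        # A task is already in flight for this tenant: buffer the new one.
        queue.push_tasks([task])
    else:
        # Nothing in flight: mark the waiting-time baseline and dispatch now.
        queue.set_task_waiting_time()
        delay(task)
```

With the default TENANT_ISOLATED_TASK_CONCURRENCY=1 from the configuration hunks below, this presumably serializes heavy indexing work per tenant while still letting the first task start without extra queueing latency.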
api/uv.lock: 1553 changes (file diff suppressed because it is too large)
|
|
@ -85,12 +85,12 @@ if [[ -z "${QUEUES}" ]]; then
|
|||
# Configure queues based on edition
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
|
||||
QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
fi
|
||||
|
||||
|
||||
echo "No queues specified, using edition-based defaults: ${QUEUES}"
|
||||
else
|
||||
echo "Using specified queues: ${QUEUES}"
|
||||
|
|
|
|||
|
|
@ -1382,3 +1382,6 @@ ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
|||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
|
|
|||
|
|
@ -620,6 +620,7 @@ x-shared-env: &shared-api-worker-env
|
|||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: ${WORKFLOW_SCHEDULE_POLLER_INTERVAL:-1}
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100}
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0}
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1}
|
||||
|
||||
services:
|
||||
# API service
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ Tutte le offerte di Dify sono dotate di API corrispondenti, permettendovi di int
|
|||
Avviate rapidamente Dify nel vostro ambiente con questa [guida di avvio rapido](#avvio-rapido). Utilizzate la nostra [documentazione](https://docs.dify.ai) per ulteriori informazioni e istruzioni dettagliate.
|
||||
|
||||
- **Dify per Aziende / Organizzazioni<br/>**
|
||||
Offriamo funzionalità aggiuntive specifiche per le aziende. [Potete comunicarci le vostre domande tramite questo chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) o [inviateci un'email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) per discutere le vostre esigenze aziendali. <br/>
|
||||
Offriamo funzionalità aggiuntive specifiche per le aziende. Potete [scriverci via email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) per discutere le vostre esigenze aziendali. <br/>
|
||||
|
||||
> Per startup e piccole imprese che utilizzano AWS, date un'occhiata a [Dify Premium su AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e distribuitelo con un solo clic nel vostro AWS VPC. Si tratta di un'offerta AMI conveniente con l'opzione di creare app con logo e branding personalizzati.
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você in
|
|||
Use nossa [documentação](https://docs.dify.ai) para referências adicionais e instruções mais detalhadas.
|
||||
|
||||
- **Dify para empresas/organizações</br>**
|
||||
Oferecemos recursos adicionais voltados para empresas. [Envie suas perguntas através deste chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) ou [envie-nos um e-mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) para discutir necessidades empresariais. </br>
|
||||
Oferecemos recursos adicionais voltados para empresas. Você pode [falar conosco por e-mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) para discutir necessidades empresariais. <br/>
|
||||
|
||||
> Para startups e pequenas empresas que utilizam AWS, confira o [Dify Premium no AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e implemente no seu próprio AWS VPC com um clique. É uma oferta AMI acessível com a opção de criar aplicativos com logotipo e marca personalizados.
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ Tất cả các dịch vụ của Dify đều đi kèm với các API tương
|
|||
Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn.
|
||||
|
||||
- **Dify cho doanh nghiệp / tổ chức</br>**
|
||||
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. </br>
|
||||
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Gửi email cho chúng tôi](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. <br/>
|
||||
|
||||
> Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
'use client'
|
||||
import { TIME_PERIOD_MAPPING } from '@/app/components/app/log/filter'
|
||||
import React, { useState } from 'react'
|
||||
import dayjs from 'dayjs'
|
||||
import quarterOfYear from 'dayjs/plugin/quarterOfYear'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { PeriodParams } from '@/app/components/app/overview/app-chart'
|
||||
import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/app-chart'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import type { Item } from '@/app/components/base/select'
|
||||
import { SimpleSelect } from '@/app/components/base/select'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import dayjs from 'dayjs'
|
||||
import quarterOfYear from 'dayjs/plugin/quarterOfYear'
|
||||
import React, { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import TimeRangePicker from './time-range-picker'
|
||||
|
||||
dayjs.extend(quarterOfYear)
|
||||
|
||||
const today = dayjs()
|
||||
|
||||
const TIME_PERIOD_MAPPING = [
|
||||
{ value: 0, name: 'today' },
|
||||
{ value: 7, name: 'last7days' },
|
||||
{ value: 30, name: 'last30days' },
|
||||
]
|
||||
|
||||
const queryDateFormat = 'YYYY-MM-DD HH:mm'
|
||||
|
||||
export type IChartViewProps = {
|
||||
|
|
@ -25,23 +28,9 @@ export type IChartViewProps = {
|
|||
export default function ChartView({ appId, headerRight }: IChartViewProps) {
|
||||
const { t } = useTranslation()
|
||||
const appDetail = useAppStore(state => state.appDetail)
|
||||
const isChatApp = appDetail?.mode !== AppModeEnum.COMPLETION && appDetail?.mode !== AppModeEnum.WORKFLOW
|
||||
const isWorkflow = appDetail?.mode === AppModeEnum.WORKFLOW
|
||||
const [period, setPeriod] = useState<PeriodParams>({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } })
|
||||
|
||||
const onSelect = (item: Item) => {
|
||||
if (item.value === -1) {
|
||||
setPeriod({ name: item.name, query: undefined })
|
||||
}
|
||||
else if (item.value === 0) {
|
||||
const startOfToday = today.startOf('day').format(queryDateFormat)
|
||||
const endOfToday = today.endOf('day').format(queryDateFormat)
|
||||
setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } })
|
||||
}
|
||||
else {
|
||||
setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } })
|
||||
}
|
||||
}
|
||||
const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow'
|
||||
const isWorkflow = appDetail?.mode === 'workflow'
|
||||
const [period, setPeriod] = useState<PeriodParams>({ name: t('appLog.filter.period.today'), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } })
|
||||
|
||||
if (!appDetail)
|
||||
return null
|
||||
|
|
@ -51,20 +40,11 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
|
|||
<div className='mb-4'>
|
||||
<div className='system-xl-semibold mb-2 text-text-primary'>{t('common.appMenus.overview')}</div>
|
||||
<div className='flex items-center justify-between'>
|
||||
<div className='flex flex-row items-center'>
|
||||
<SimpleSelect
|
||||
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
|
||||
className='mt-0 !w-40'
|
||||
notClearable={true}
|
||||
onSelect={(item) => {
|
||||
const id = item.value
|
||||
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'
|
||||
const name = item.name || t('appLog.filter.period.allTime')
|
||||
onSelect({ value, name })
|
||||
}}
|
||||
defaultValue={'2'}
|
||||
/>
|
||||
</div>
|
||||
<TimeRangePicker
|
||||
ranges={TIME_PERIOD_MAPPING}
|
||||
onSelect={setPeriod}
|
||||
queryDateFormat={queryDateFormat}
|
||||
/>
|
||||
{headerRight}
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
'use client'
|
||||
import { RiCalendarLine } from '@remixicon/react'
|
||||
import type { Dayjs } from 'dayjs'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { formatToLocalTime } from '@/utils/format'
|
||||
import { useI18N } from '@/context/i18n'
|
||||
import Picker from '@/app/components/base/date-and-time-picker/date-picker'
|
||||
import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types'
|
||||
import { noop } from 'lodash-es'
|
||||
import dayjs from 'dayjs'
|
||||
|
||||
type Props = {
|
||||
start: Dayjs
|
||||
end: Dayjs
|
||||
onStartChange: (date?: Dayjs) => void
|
||||
onEndChange: (date?: Dayjs) => void
|
||||
}
|
||||
|
||||
const today = dayjs()
|
||||
const DatePicker: FC<Props> = ({
|
||||
start,
|
||||
end,
|
||||
onStartChange,
|
||||
onEndChange,
|
||||
}) => {
|
||||
const { locale } = useI18N()
|
||||
|
||||
const renderDate = useCallback(({ value, handleClickTrigger, isOpen }: TriggerProps) => {
|
||||
return (
|
||||
<div className={cn('system-sm-regular flex h-7 cursor-pointer items-center rounded-lg px-1 text-components-input-text-filled hover:bg-state-base-hover', isOpen && 'bg-state-base-hover')} onClick={handleClickTrigger}>
|
||||
{value ? formatToLocalTime(value, locale, 'MMM D') : ''}
|
||||
</div>
|
||||
)
|
||||
}, [locale])
|
||||
|
||||
const availableStartDate = end.subtract(30, 'day')
|
||||
const startDateDisabled = useCallback((date: Dayjs) => {
|
||||
if (date.isAfter(today, 'date'))
|
||||
return true
|
||||
return !((date.isAfter(availableStartDate, 'date') || date.isSame(availableStartDate, 'date')) && (date.isBefore(end, 'date') || date.isSame(end, 'date')))
|
||||
}, [availableStartDate, end])
|
||||
|
||||
const availableEndDate = start.add(30, 'day')
|
||||
const endDateDisabled = useCallback((date: Dayjs) => {
|
||||
if (date.isAfter(today, 'date'))
|
||||
return true
|
||||
return !((date.isAfter(start, 'date') || date.isSame(start, 'date')) && (date.isBefore(availableEndDate, 'date') || date.isSame(availableEndDate, 'date')))
|
||||
}, [availableEndDate, start])
|
||||
|
||||
return (
|
||||
<div className='flex h-8 items-center space-x-0.5 rounded-lg bg-components-input-bg-normal px-2'>
|
||||
<div className='p-px'>
|
||||
<RiCalendarLine className='size-3.5 text-text-tertiary' />
|
||||
</div>
|
||||
<Picker
|
||||
value={start}
|
||||
onChange={onStartChange}
|
||||
renderTrigger={renderDate}
|
||||
needTimePicker={false}
|
||||
onClear={noop}
|
||||
noConfirm
|
||||
getIsDateDisabled={startDateDisabled}
|
||||
/>
|
||||
<span className='system-sm-regular text-text-tertiary'>-</span>
|
||||
<Picker
|
||||
value={end}
|
||||
onChange={onEndChange}
|
||||
renderTrigger={renderDate}
|
||||
needTimePicker={false}
|
||||
onClear={noop}
|
||||
noConfirm
|
||||
getIsDateDisabled={endDateDisabled}
|
||||
/>
|
||||
</div>
|
||||
|
||||
)
|
||||
}
|
||||
export default React.memo(DatePicker)
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
'use client'
|
||||
import type { PeriodParams, PeriodParamsWithTimeRange } from '@/app/components/app/overview/app-chart'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useState } from 'react'
|
||||
import type { Dayjs } from 'dayjs'
|
||||
import { HourglassShape } from '@/app/components/base/icons/src/vender/other'
|
||||
import RangeSelector from './range-selector'
|
||||
import DatePicker from './date-picker'
|
||||
import dayjs from 'dayjs'
|
||||
import { useI18N } from '@/context/i18n'
|
||||
import { formatToLocalTime } from '@/utils/format'
|
||||
|
||||
const today = dayjs()
|
||||
|
||||
type Props = {
|
||||
ranges: { value: number; name: string }[]
|
||||
onSelect: (payload: PeriodParams) => void
|
||||
queryDateFormat: string
|
||||
}
|
||||
|
||||
const TimeRangePicker: FC<Props> = ({
|
||||
ranges,
|
||||
onSelect,
|
||||
queryDateFormat,
|
||||
}) => {
|
||||
const { locale } = useI18N()
|
||||
|
||||
const [isCustomRange, setIsCustomRange] = useState(false)
|
||||
const [start, setStart] = useState<Dayjs>(today)
|
||||
const [end, setEnd] = useState<Dayjs>(today)
|
||||
|
||||
const handleRangeChange = useCallback((payload: PeriodParamsWithTimeRange) => {
|
||||
setIsCustomRange(false)
|
||||
setStart(payload.query!.start)
|
||||
setEnd(payload.query!.end)
|
||||
onSelect({
|
||||
name: payload.name,
|
||||
query: {
|
||||
start: payload.query!.start.format(queryDateFormat),
|
||||
end: payload.query!.end.format(queryDateFormat),
|
||||
},
|
||||
})
|
||||
}, [onSelect, queryDateFormat])
|
||||
|
||||
const handleDateChange = useCallback((type: 'start' | 'end') => {
|
||||
return (date?: Dayjs) => {
|
||||
if (!date) return
|
||||
if (type === 'start' && date.isSame(start)) return
|
||||
if (type === 'end' && date.isSame(end)) return
|
||||
if (type === 'start')
|
||||
setStart(date)
|
||||
else
|
||||
setEnd(date)
|
||||
|
||||
const currStart = type === 'start' ? date : start
|
||||
const currEnd = type === 'end' ? date : end
|
||||
onSelect({
|
||||
name: `${formatToLocalTime(currStart, locale, 'MMM D')} - ${formatToLocalTime(currEnd, locale, 'MMM D')}`,
|
||||
query: {
|
||||
start: currStart.format(queryDateFormat),
|
||||
end: currEnd.format(queryDateFormat),
|
||||
},
|
||||
})
|
||||
|
||||
setIsCustomRange(true)
|
||||
}
|
||||
}, [start, end, onSelect, locale, queryDateFormat])
|
||||
|
||||
return (
|
||||
<div className='flex items-center'>
|
||||
<RangeSelector
|
||||
isCustomRange={isCustomRange}
|
||||
ranges={ranges}
|
||||
onSelect={handleRangeChange}
|
||||
/>
|
||||
<HourglassShape className='h-3.5 w-2 text-components-input-bg-normal' />
|
||||
<DatePicker
|
||||
start={start}
|
||||
end={end}
|
||||
onStartChange={handleDateChange('start')}
|
||||
onEndChange={handleDateChange('end')}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(TimeRangePicker)
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
'use client'
|
||||
import type { PeriodParamsWithTimeRange, TimeRange } from '@/app/components/app/overview/app-chart'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import { SimpleSelect } from '@/app/components/base/select'
|
||||
import type { Item } from '@/app/components/base/select'
|
||||
import dayjs from 'dayjs'
|
||||
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
const today = dayjs()
|
||||
|
||||
type Props = {
|
||||
isCustomRange: boolean
|
||||
ranges: { value: number; name: string }[]
|
||||
onSelect: (payload: PeriodParamsWithTimeRange) => void
|
||||
}
|
||||
|
||||
const RangeSelector: FC<Props> = ({
|
||||
isCustomRange,
|
||||
ranges,
|
||||
onSelect,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const handleSelectRange = useCallback((item: Item) => {
|
||||
const { name, value } = item
|
||||
let period: TimeRange | null = null
|
||||
if (value === 0) {
|
||||
const startOfToday = today.startOf('day')
|
||||
const endOfToday = today.endOf('day')
|
||||
period = { start: startOfToday, end: endOfToday }
|
||||
}
|
||||
else {
|
||||
period = { start: today.subtract(item.value as number, 'day').startOf('day'), end: today.endOf('day') }
|
||||
}
|
||||
onSelect({ query: period!, name })
|
||||
}, [onSelect])
|
||||
|
||||
const renderTrigger = useCallback((item: Item | null, isOpen: boolean) => {
|
||||
return (
|
||||
<div className={cn('flex h-8 cursor-pointer items-center space-x-1.5 rounded-lg bg-components-input-bg-normal pl-3 pr-2', isOpen && 'bg-state-base-hover-alt')}>
|
||||
<div className='system-sm-regular text-components-input-text-filled'>{isCustomRange ? t('appLog.filter.period.custom') : item?.name}</div>
|
||||
<RiArrowDownSLine className={cn('size-4 text-text-quaternary', isOpen && 'text-text-secondary')} />
|
||||
</div>
|
||||
)
|
||||
}, [isCustomRange])
|
||||
|
||||
const renderOption = useCallback(({ item, selected }: { item: Item; selected: boolean }) => {
|
||||
return (
|
||||
<>
|
||||
{selected && (
|
||||
<span
|
||||
className={cn(
|
||||
'absolute left-2 top-[9px] flex items-center text-text-accent',
|
||||
)}
|
||||
>
|
||||
<RiCheckLine className="h-4 w-4" aria-hidden="true" />
|
||||
</span>
|
||||
)}
|
||||
<span className={cn('system-md-regular block truncate')}>{item.name}</span>
|
||||
</>
|
||||
)
|
||||
}, [])
|
||||
return (
|
||||
<SimpleSelect
|
||||
items={ranges.map(v => ({ ...v, name: t(`appLog.filter.period.${v.name}`) }))}
|
||||
className='mt-0 !w-40'
|
||||
notClearable={true}
|
||||
onSelect={handleSelectRange}
|
||||
defaultValue={0}
|
||||
wrapperClassName='h-8'
|
||||
optionWrapClassName='w-[200px] translate-x-[-24px]'
|
||||
renderTrigger={renderTrigger}
|
||||
optionClassName='flex items-center py-0 pl-7 pr-2 h-8'
|
||||
renderOption={renderOption}
|
||||
/>
|
||||
)
|
||||
}
|
||||
export default React.memo(RangeSelector)
|
||||
|
|
@ -4,6 +4,7 @@ import React from 'react'
|
|||
import ReactECharts from 'echarts-for-react'
|
||||
import type { EChartsOption } from 'echarts'
|
||||
import useSWR from 'swr'
|
||||
import type { Dayjs } from 'dayjs'
|
||||
import dayjs from 'dayjs'
|
||||
import { get } from 'lodash-es'
|
||||
import Decimal from 'decimal.js'
|
||||
|
|
@ -78,6 +79,16 @@ export type PeriodParams = {
|
|||
}
|
||||
}
|
||||
|
||||
export type TimeRange = {
|
||||
start: Dayjs
|
||||
end: Dayjs
|
||||
}
|
||||
|
||||
export type PeriodParamsWithTimeRange = {
|
||||
name: string
|
||||
query?: TimeRange
|
||||
}
|
||||
|
||||
export type IBizChartProps = {
|
||||
period: PeriodParams
|
||||
id: string
|
||||
|
|
@ -215,9 +226,7 @@ const Chart: React.FC<IChartProps> = ({
|
|||
formatter(params) {
|
||||
return `<div style='color:#6B7280;font-size:12px'>${params.name}</div>
|
||||
<div style='font-size:14px;color:#1F2A37'>${valueFormatter((params.data as any)[yField])}
|
||||
${!CHART_TYPE_CONFIG[chartType].showTokens
|
||||
? ''
|
||||
: `<span style='font-size:12px'>
|
||||
${!CHART_TYPE_CONFIG[chartType].showTokens ? '' : `<span style='font-size:12px'>
|
||||
<span style='margin-left:4px;color:#6B7280'>(</span>
|
||||
<span style='color:#FF8A4C'>~$${get(params.data, 'total_price', 0)}</span>
|
||||
<span style='color:#6B7280'>)</span>
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@ const Calendar: FC<CalendarProps> = ({
|
|||
selectedDate,
|
||||
onDateClick,
|
||||
wrapperClassName,
|
||||
getIsDateDisabled,
|
||||
}) => {
|
||||
return <div className={wrapperClassName}>
|
||||
<DaysOfWeek/>
|
||||
<DaysOfWeek />
|
||||
<div className='grid grid-cols-7 gap-0.5 p-2'>
|
||||
{
|
||||
days.map(day => <CalendarItem
|
||||
|
|
@ -18,6 +19,7 @@ const Calendar: FC<CalendarProps> = ({
|
|||
day={day}
|
||||
selectedDate={selectedDate}
|
||||
onClick={onDateClick}
|
||||
isDisabled={getIsDateDisabled ? getIsDateDisabled(day.date) : false}
|
||||
/>)
|
||||
}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ const Item: FC<CalendarItemProps> = ({
|
|||
day,
|
||||
selectedDate,
|
||||
onClick,
|
||||
isDisabled,
|
||||
}) => {
|
||||
const { date, isCurrentMonth } = day
|
||||
const isSelected = selectedDate?.isSame(date, 'date')
|
||||
|
|
@ -14,11 +15,12 @@ const Item: FC<CalendarItemProps> = ({
|
|||
|
||||
return (
|
||||
<button type="button"
|
||||
onClick={() => onClick(date)}
|
||||
onClick={() => !isDisabled && onClick(date)}
|
||||
className={cn(
|
||||
'system-sm-medium relative flex items-center justify-center rounded-lg px-1 py-2',
|
||||
isCurrentMonth ? 'text-text-secondary' : 'text-text-quaternary hover:text-text-secondary',
|
||||
isSelected ? 'system-sm-medium bg-components-button-primary-bg text-components-button-primary-text' : 'hover:bg-state-base-hover',
|
||||
isDisabled && 'cursor-not-allowed text-text-quaternary hover:bg-transparent',
|
||||
)}
|
||||
>
|
||||
{date.date()}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ const DatePicker = ({
|
|||
renderTrigger,
|
||||
triggerWrapClassName,
|
||||
popupZIndexClassname = 'z-[11]',
|
||||
notClearable = false,
|
||||
noConfirm,
|
||||
getIsDateDisabled,
|
||||
}: DatePickerProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
|
|
@ -121,11 +122,20 @@ const DatePicker = ({
|
|||
setCurrentDate(currentDate.clone().subtract(1, 'month'))
|
||||
}, [currentDate])
|
||||
|
||||
const handleConfirmDate = useCallback((passedInSelectedDate?: Dayjs) => {
|
||||
// passedInSelectedDate may be a click event when noConfirm is false
|
||||
const nextDate = (dayjs.isDayjs(passedInSelectedDate) ? passedInSelectedDate : selectedDate)
|
||||
onChange(nextDate ? nextDate.tz(timezone) : undefined)
|
||||
setIsOpen(false)
|
||||
}, [selectedDate, onChange, timezone])
|
||||
|
||||
const handleDateSelect = useCallback((day: Dayjs) => {
|
||||
const newDate = cloneTime(day, selectedDate || getDateWithTimezone({ timezone }))
|
||||
setCurrentDate(newDate)
|
||||
setSelectedDate(newDate)
|
||||
}, [selectedDate, timezone])
|
||||
if (noConfirm)
|
||||
handleConfirmDate(newDate)
|
||||
}, [selectedDate, timezone, noConfirm, handleConfirmDate])
|
||||
|
||||
const handleSelectCurrentDate = () => {
|
||||
const newDate = getDateWithTimezone({ timezone })
|
||||
|
|
@ -135,12 +145,6 @@ const DatePicker = ({
|
|||
setIsOpen(false)
|
||||
}
|
||||
|
||||
const handleConfirmDate = () => {
|
||||
// debugger
|
||||
onChange(selectedDate ? selectedDate.tz(timezone) : undefined)
|
||||
setIsOpen(false)
|
||||
}
|
||||
|
||||
const handleClickTimePicker = () => {
|
||||
if (view === ViewType.date) {
|
||||
setView(ViewType.time)
|
||||
|
|
@ -208,7 +212,7 @@ const DatePicker = ({
|
|||
<PortalToFollowElem
|
||||
open={isOpen}
|
||||
onOpenChange={setIsOpen}
|
||||
placement='bottom-start'
|
||||
placement='bottom-end'
|
||||
>
|
||||
<PortalToFollowElemTrigger className={triggerWrapClassName}>
|
||||
{renderTrigger ? (renderTrigger({
|
||||
|
|
@ -232,17 +236,15 @@ const DatePicker = ({
|
|||
<RiCalendarLine className={cn(
|
||||
'h-4 w-4 shrink-0 text-text-quaternary',
|
||||
isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary',
|
||||
(displayValue || (isOpen && selectedDate)) && !notClearable && 'group-hover:hidden',
|
||||
(displayValue || (isOpen && selectedDate)) && 'group-hover:hidden',
|
||||
)} />
|
||||
{!notClearable && (
|
||||
<RiCloseCircleFill
|
||||
className={cn(
|
||||
'hidden h-4 w-4 shrink-0 text-text-quaternary',
|
||||
(displayValue || (isOpen && selectedDate)) && 'hover:text-text-secondary group-hover:inline-block',
|
||||
)}
|
||||
onClick={handleClear}
|
||||
/>
|
||||
)}
|
||||
<RiCloseCircleFill
|
||||
className={cn(
|
||||
'hidden h-4 w-4 shrink-0 text-text-quaternary',
|
||||
(displayValue || (isOpen && selectedDate)) && 'hover:text-text-secondary group-hover:inline-block',
|
||||
)}
|
||||
onClick={handleClear}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</PortalToFollowElemTrigger>
|
||||
|
|
@ -273,6 +275,7 @@ const DatePicker = ({
|
|||
days={days}
|
||||
selectedDate={selectedDate}
|
||||
onDateClick={handleDateSelect}
|
||||
getIsDateDisabled={getIsDateDisabled}
|
||||
/>
|
||||
) : view === ViewType.yearMonth ? (
|
||||
<YearAndMonthPickerOptions
|
||||
|
|
@ -293,7 +296,7 @@ const DatePicker = ({
|
|||
|
||||
{/* Footer */}
|
||||
{
|
||||
[ViewType.date, ViewType.time].includes(view) ? (
|
||||
[ViewType.date, ViewType.time].includes(view) && !noConfirm && (
|
||||
<DatePickerFooter
|
||||
needTimePicker={needTimePicker}
|
||||
displayTime={displayTime}
|
||||
|
|
@ -302,7 +305,10 @@ const DatePicker = ({
|
|||
handleSelectCurrentDate={handleSelectCurrentDate}
|
||||
handleConfirmDate={handleConfirmDate}
|
||||
/>
|
||||
) : (
|
||||
)
|
||||
}
|
||||
{
|
||||
![ViewType.date, ViewType.time].includes(view) && (
|
||||
<YearAndMonthPickerFooter
|
||||
handleYearMonthCancel={handleYearMonthCancel}
|
||||
handleYearMonthConfirm={handleYearMonthConfirm}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ export type DatePickerProps = {
|
|||
renderTrigger?: (props: TriggerProps) => React.ReactNode
|
||||
minuteFilter?: (minutes: string[]) => string[]
|
||||
popupZIndexClassname?: string
|
||||
notClearable?: boolean
|
||||
noConfirm?: boolean
|
||||
getIsDateDisabled?: (date: Dayjs) => boolean
|
||||
}
|
||||
|
||||
export type DatePickerHeaderProps = {
|
||||
|
|
@ -64,12 +65,6 @@ export type TimePickerProps = {
|
|||
title?: string
|
||||
minuteFilter?: (minutes: string[]) => string[]
|
||||
popupClassName?: string
|
||||
notClearable?: boolean
|
||||
triggerFullWidth?: boolean
|
||||
/** Show timezone label inline with the time picker */
|
||||
showTimezone?: boolean
|
||||
/** Placement of the popup relative to the trigger */
|
||||
placement?: 'bottom-start' | 'bottom-end' | 'bottom'
|
||||
}
|
||||
|
||||
export type TimePickerFooterProps = {
|
||||
|
|
@ -87,12 +82,14 @@ export type CalendarProps = {
|
|||
selectedDate: Dayjs | undefined
|
||||
onDateClick: (date: Dayjs) => void
|
||||
wrapperClassName?: string
|
||||
getIsDateDisabled?: (date: Dayjs) => boolean
|
||||
}
|
||||
|
||||
export type CalendarItemProps = {
|
||||
day: Day
|
||||
selectedDate: Dayjs | undefined
|
||||
onClick: (date: Dayjs) => void
|
||||
isDisabled: boolean
|
||||
}
|
||||
|
||||
export type TimeOptionsProps = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
<svg width="8" height="14" viewBox="0 0 8 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8 14C8 11.7909 6.20914 10 4 10C1.79086 10 0 11.7909 0 14V0C8.05332e-08 2.20914 1.79086 4 4 4C6.20914 4 8 2.20914 8 0V14Z" fill="#C8CEDA" fill-opacity="1"/>
|
||||
</svg>
|
||||
|
After: 267 B (new file)
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "8",
|
||||
"height": "14",
|
||||
"viewBox": "0 0 8 14",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M8 14C8 11.7909 6.20914 10 4 10C1.79086 10 0 11.7909 0 14V0C8.05332e-08 2.20914 1.79086 4 4 4C6.20914 4 8 2.20914 8 0V14Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "1"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "HourglassShape"
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './HourglassShape.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'HourglassShape'
|
||||
|
||||
export default Icon
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
export { default as AnthropicText } from './AnthropicText'
|
||||
export { default as Generator } from './Generator'
|
||||
export { default as Group } from './Group'
|
||||
export { default as HourglassShape } from './HourglassShape'
|
||||
export { default as Mcp } from './Mcp'
|
||||
export { default as NoToolPlaceholder } from './NoToolPlaceholder'
|
||||
export { default as Openai } from './Openai'
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ export type Item = {
|
|||
export type ISelectProps = {
|
||||
className?: string
|
||||
wrapperClassName?: string
|
||||
renderTrigger?: (value: Item | null) => React.JSX.Element | null
|
||||
renderTrigger?: (value: Item | null, isOpen: boolean) => React.JSX.Element | null
|
||||
items?: Item[]
|
||||
defaultValue?: number | string
|
||||
disabled?: boolean
|
||||
|
|
@ -222,7 +222,7 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
>
|
||||
{({ open }) => (
|
||||
<div className={classNames('group/simple-select relative h-9', wrapperClassName)}>
|
||||
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem)}</ListboxButton>}
|
||||
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem, open)}</ListboxButton>}
|
||||
{!renderTrigger && (
|
||||
<ListboxButton onClick={() => {
|
||||
onOpenChange?.(open)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,8 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
If set to `false`, asynchronous title generation can be achieved by calling the conversation rename API with `auto_generate` set to `true` (a request sketch follows this property).
|
||||
</Property>
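A minimal sketch of the asynchronous title flow described above. It assumes the conversation rename endpoint is `POST /conversations/{conversation_id}/name` as documented elsewhere in this API reference, and that `API_BASE` and `API_KEY` come from your own deployment; adjust both if your setup differs.

```ts
// Sketch: when title auto-generation is disabled on the chat request,
// a title can still be generated later via the conversation rename API.
const API_BASE = 'https://api.dify.ai/v1' // adjust for self-hosted deployments
const API_KEY = process.env.DIFY_API_KEY

async function autoGenerateTitle(conversationId: string, user: string) {
  const res = await fetch(`${API_BASE}/conversations/${conversationId}/name`, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${API_KEY}`,
      'Content-Type': 'application/json',
    },
    // auto_generate: true asks the server to derive the title from the conversation
    body: JSON.stringify({ auto_generate: true, user }),
  })
  if (!res.ok)
    throw new Error(`Conversation rename failed: ${res.status}`)
  return res.json()
}
```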
|
||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.
|
||||
(Optional) Workflow ID that pins the request to a specific version; if not provided, the default published version is used.<br/>
|
||||
How to obtain: In the version history interface, click the copy icon on the right side of each version entry to copy the complete workflow ID.
|
||||
</Property>
|
||||
<Property name='trace_id' type='string' key='trace_id'>
|
||||
(Optional) Trace ID. Used to integrate with the trace component of an existing business system for end-to-end distributed tracing. If not provided, the system generates a trace_id automatically. It can be passed in any of the following three ways, in order of priority:<br/>
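A hedged request sketch showing how the `workflow_id` and `trace_id` parameters described above can be sent in the request body of this chat endpoint. The base URL and both ID values are placeholders to adapt to your app: `workflow_id` is the ID copied from the version history, and `trace_id` comes from your own tracing system.

```ts
// Sketch: pin a chat request to a specific published workflow version and
// propagate an externally generated trace ID via the request body.
const API_BASE = 'https://api.dify.ai/v1' // adjust for self-hosted deployments
const API_KEY = process.env.DIFY_API_KEY

async function sendChatMessage(query: string, user: string) {
  const res = await fetch(`${API_BASE}/chat-messages`, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${API_KEY}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      query,
      user,
      inputs: {},
      response_mode: 'blocking',
      workflow_id: '<id copied from the version history>', // placeholder
      trace_id: '<id from your own tracing system>', // placeholder
    }),
  })
  if (!res.ok)
    throw new Error(`Chat request failed: ${res.status}`)
  return res.json()
}
```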
|
||||
|
|
|
|||
|
|
@ -74,7 +74,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
`false`に設定すると、会話のリネームAPIを呼び出し、`auto_generate`を`true`に設定することで非同期タイトル生成を実現できます。
|
||||
</Property>
|
||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||
(オプション)ワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。
|
||||
(オプション)ワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。<br/>
|
||||
取得方法:バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
|
||||
</Property>
|
||||
<Property name='trace_id' type='string' key='trace_id'>
|
||||
(オプション)トレースID。既存の業務システムのトレースコンポーネントと連携し、エンドツーエンドの分散トレーシングを実現するために使用します。指定がない場合、システムが自動的に trace_id を生成します。以下の3つの方法で渡すことができ、優先順位は次のとおりです:<br/>
|
||||
|
|
|
|||
|
|
@ -72,7 +72,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
(选填)自动生成标题,默认 `true`。 若设置为 `false`,则可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
|
||||
</Property>
|
||||
<Property name='workflow_id' type='string' key='workflow_id'>
|
||||
(选填)工作流ID,用于指定特定版本,如果不提供则使用默认的已发布版本。
|
||||
(选填)工作流ID,用于指定特定版本,如果不提供则使用默认的已发布版本。<br/>
|
||||
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
|
||||
</Property>
|
||||
<Property name='trace_id' type='string' key='trace_id'>
|
||||
(选填)链路追踪ID。适用于与业务系统已有的trace组件打通,实现端到端分布式追踪等场景。如果未指定,系统会自动生成<code>trace_id</code>。支持以下三种方式传递,具体优先级依次为:<br/>
|
||||
|
|
|
|||
|
|
@ -344,7 +344,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
### パス
|
||||
- `workflow_id` (string) 必須 特定バージョンのワークフローを指定するためのワークフローID
|
||||
|
||||
取得方法:バージョン履歴で特定バージョンのワークフローIDを照会できます。
|
||||
取得方法:バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
|
||||
|
||||
### リクエストボディ
|
||||
- `inputs` (object) 必須
|
||||
|
|
|
|||
|
|
@ -334,7 +334,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等
|
|||
### Path
|
||||
- `workflow_id` (string) Required 工作流ID,用于指定特定版本的工作流
|
||||
|
||||
获取方式:可以在版本历史中查询特定版本的工作流ID。
|
||||
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
|
||||
|
||||
### Request Body
|
||||
- `inputs` (object) Required
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ const ModelList: FC<ModelListProps> = ({
|
|||
{
|
||||
models.map(model => (
|
||||
<ModelListItem
|
||||
key={`${model.model}-${model.fetch_from}`}
|
||||
key={`${model.model}-${model.model_type}-${model.fetch_from}`}
|
||||
{...{
|
||||
model,
|
||||
provider,
|
||||
|
|
|
|||
|
|
@ -856,6 +856,18 @@
|
|||
color: var(--color-prettylights-syntax-comment);
|
||||
}
|
||||
|
||||
.markdown-body .katex {
|
||||
/* Allow long inline formulas to wrap instead of overflowing */
|
||||
white-space: normal !important;
|
||||
overflow-wrap: break-word; /* better cross-browser support */
|
||||
word-break: break-word; /* non-standard fallback for older WebKit/Blink */
|
||||
}
|
||||
|
||||
.markdown-body .katex-display {
|
||||
/* Fallback for very long display equations */
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-body .pl-c1,
|
||||
.markdown-body .pl-s .pl-v {
|
||||
color: var(--color-prettylights-syntax-constant);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import MaintenanceNotice from '@/app/components/header/maintenance-notice'
|
|||
import { noop } from 'lodash-es'
|
||||
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
|
||||
import { ZENDESK_FIELD_IDS } from '@/config'
|
||||
import { useGlobalPublicStore } from './global-public-context'
|
||||
|
||||
export type AppContextValue = {
|
||||
userProfile: UserProfileResponse
|
||||
|
|
@ -77,6 +78,7 @@ export type AppContextProviderProps = {
|
|||
}
|
||||
|
||||
export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) => {
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: userProfileResponse, mutate: mutateUserProfile, error: userProfileError } = useSWR({ url: '/account/profile', params: {} }, fetchUserProfile)
|
||||
const { data: currentWorkspaceResponse, mutate: mutateCurrentWorkspace, isLoading: isLoadingCurrentWorkspace } = useSWR({ url: '/workspaces/current', params: {} }, fetchCurrentWorkspace)
|
||||
|
||||
|
|
@ -92,10 +94,12 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
|
|||
try {
|
||||
const result = await userProfileResponse.json()
|
||||
setUserProfile(result)
|
||||
const current_version = userProfileResponse.headers.get('x-version')
|
||||
const current_env = process.env.NODE_ENV === 'development' ? 'DEVELOPMENT' : userProfileResponse.headers.get('x-env')
|
||||
const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } })
|
||||
setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
|
||||
if (!systemFeatures.branding.enabled) {
|
||||
const current_version = userProfileResponse.headers.get('x-version')
|
||||
const current_env = process.env.NODE_ENV === 'development' ? 'DEVELOPMENT' : userProfileResponse.headers.get('x-env')
|
||||
const versionData = await fetchLangGeniusVersion({ url: '/version', params: { current_version } })
|
||||
setLangGeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env })
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Failed to update user profile:', error)
|
||||
|
|
|
|||
|
|
@ -66,6 +66,8 @@ const translation = {
|
|||
quarterToDate: 'Quartal bis heute',
|
||||
yearToDate: 'Jahr bis heute',
|
||||
allTime: 'Gesamte Zeit',
|
||||
last30days: 'Letzte 30 Tage',
|
||||
custom: 'Benutzerdefiniert',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Alle',
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ const translation = {
|
|||
period: {
|
||||
today: 'Today',
|
||||
last7days: 'Last 7 Days',
|
||||
last30days: 'Last 30 Days',
|
||||
last4weeks: 'Last 4 weeks',
|
||||
last3months: 'Last 3 months',
|
||||
last12months: 'Last 12 months',
|
||||
|
|
@ -67,6 +68,7 @@ const translation = {
|
|||
quarterToDate: 'Quarter to date',
|
||||
yearToDate: 'Year to date',
|
||||
allTime: 'All time',
|
||||
custom: 'Custom',
|
||||
},
|
||||
annotation: {
|
||||
all: 'All',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Trimestre hasta la fecha',
|
||||
yearToDate: 'Año hasta la fecha',
|
||||
allTime: 'Todo el tiempo',
|
||||
custom: 'Personalizado',
|
||||
last30days: 'Últimos 30 días',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Todos',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'از ابتدای فصل تاکنون',
|
||||
yearToDate: 'از ابتدای سال تاکنون',
|
||||
allTime: 'همه زمانها',
|
||||
last30days: '۳۰ روز گذشته',
|
||||
custom: 'سفارشی',
|
||||
},
|
||||
annotation: {
|
||||
all: 'همه',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Trimestre à ce jour',
|
||||
yearToDate: 'Année à ce jour',
|
||||
allTime: 'Tout le temps',
|
||||
custom: 'Personnalisé',
|
||||
last30days: 'Derniers 30 jours',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Tous',
|
||||
|
|
|
|||
|
|
@ -67,6 +67,8 @@ const translation = {
|
|||
quarterToDate: 'तिमाही तक तिथि',
|
||||
yearToDate: 'वर्ष तक तिथि',
|
||||
allTime: 'सभी समय',
|
||||
last30days: 'पिछले 30 दिन',
|
||||
custom: 'कस्टम',
|
||||
},
|
||||
annotation: {
|
||||
all: 'सभी',
|
||||
|
|
|
|||
|
|
@ -60,6 +60,8 @@ const translation = {
|
|||
yearToDate: 'Tahun hingga saat ini',
|
||||
allTime: 'Sepanjang masa',
|
||||
last12months: '12 bulan terakhir',
|
||||
custom: 'Kustom',
|
||||
last30days: '30 Hari Terakhir',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Semua',
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ const translation = {
|
|||
quarterToDate: 'Trimestre corrente',
|
||||
yearToDate: 'Anno corrente',
|
||||
allTime: 'Tutto il tempo',
|
||||
custom: 'Personalizzato',
|
||||
last30days: 'Ultimi 30 giorni',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Tutti',
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ const translation = {
|
|||
period: {
|
||||
today: '今日',
|
||||
last7days: '過去 7 日間',
|
||||
last30days: '過去 30 日間',
|
||||
last4weeks: '過去 4 週間',
|
||||
last3months: '過去 3 ヶ月',
|
||||
last12months: '過去 12 ヶ月',
|
||||
|
|
@ -67,6 +68,7 @@ const translation = {
|
|||
quarterToDate: '四半期初から今日まで',
|
||||
yearToDate: '年初から今日まで',
|
||||
allTime: 'すべての期間',
|
||||
custom: 'カスタム',
|
||||
},
|
||||
annotation: {
|
||||
all: 'すべて',
|
||||
|
|
|
|||
|
|
@ -66,6 +66,8 @@ const translation = {
|
|||
quarterToDate: '분기 초부터 오늘까지',
|
||||
yearToDate: '연 초부터 오늘까지',
|
||||
allTime: '모든 기간',
|
||||
last30days: '최근 30일',
|
||||
custom: '사용자 정의',
|
||||
},
|
||||
annotation: {
|
||||
all: '모두',
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ const translation = {
|
|||
quarterToDate: 'Od początku kwartału',
|
||||
yearToDate: 'Od początku roku',
|
||||
allTime: 'Cały czas',
|
||||
custom: 'Niestandardowy',
|
||||
last30days: 'Ostatnie 30 dni',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Wszystkie',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Trimestre até hoje',
|
||||
yearToDate: 'Ano até hoje',
|
||||
allTime: 'Todo o tempo',
|
||||
custom: 'Personalizado',
|
||||
last30days: 'Últimos 30 Dias',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Tudo',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Trimestrul curent',
|
||||
yearToDate: 'Anul curent',
|
||||
allTime: 'Tot timpul',
|
||||
custom: 'Personalizat',
|
||||
last30days: 'Ultimele 30 de zile',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Toate',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'С начала квартала',
|
||||
yearToDate: 'С начала года',
|
||||
allTime: 'Все время',
|
||||
last30days: 'Последние 30 дней',
|
||||
custom: 'Кастомный',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Все',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Četrtletje do danes',
|
||||
yearToDate: 'Leto do danes',
|
||||
allTime: 'Vse obdobje',
|
||||
last30days: 'Zadnjih 30 dni',
|
||||
custom: 'Po meri',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Vse',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'ไตรมาสจนถึงปัจจุบัน',
|
||||
yearToDate: 'ปีจนถึงปัจจุบัน',
|
||||
allTime: 'ตลอดเวลา',
|
||||
last30days: '30 วันที่ผ่านมา',
|
||||
custom: 'กำหนดเอง',
|
||||
},
|
||||
annotation: {
|
||||
all: 'ทั้งหมด',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Çeyrek Başlangıcından İtibaren',
|
||||
yearToDate: 'Yıl Başlangıcından İtibaren',
|
||||
allTime: 'Tüm Zamanlar',
|
||||
custom: 'Özel',
|
||||
last30days: 'Son 30 Gün',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Hepsi',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Квартал до сьогодні',
|
||||
yearToDate: 'Рік до сьогодні',
|
||||
allTime: 'За весь час',
|
||||
last30days: 'Останні 30 днів',
|
||||
custom: 'Користувацький',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Всі',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: 'Quý hiện tại',
|
||||
yearToDate: 'Năm hiện tại',
|
||||
allTime: 'Tất cả thời gian',
|
||||
custom: 'Tùy chỉnh',
|
||||
last30days: '30 Ngày Qua',
|
||||
},
|
||||
annotation: {
|
||||
all: 'Tất cả',
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ const translation = {
|
|||
period: {
|
||||
today: '今天',
|
||||
last7days: '过去 7 天',
|
||||
last30days: '过去 30 天',
|
||||
last4weeks: '过去 4 周',
|
||||
last3months: '过去 3 月',
|
||||
last12months: '过去 12 月',
|
||||
|
|
@ -67,6 +68,7 @@ const translation = {
|
|||
quarterToDate: '本季度至今',
|
||||
yearToDate: '本年至今',
|
||||
allTime: '所有时间',
|
||||
custom: '自定义',
|
||||
},
|
||||
annotation: {
|
||||
all: '全部',
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ const translation = {
|
|||
quarterToDate: '本季度至今',
|
||||
yearToDate: '本年至今',
|
||||
allTime: '所有時間',
|
||||
last30days: '過去30天',
|
||||
custom: '自訂',
|
||||
},
|
||||
annotation: {
|
||||
all: '全部',
|
||||
|
|
|
|||
|
|
@ -1,3 +1,50 @@
|
|||
import type { Locale } from '@/i18n-config'
|
||||
import type { Dayjs } from 'dayjs'
|
||||
import 'dayjs/locale/de'
|
||||
import 'dayjs/locale/es'
|
||||
import 'dayjs/locale/fa'
|
||||
import 'dayjs/locale/fr'
|
||||
import 'dayjs/locale/hi'
|
||||
import 'dayjs/locale/id'
|
||||
import 'dayjs/locale/it'
|
||||
import 'dayjs/locale/ja'
|
||||
import 'dayjs/locale/ko'
|
||||
import 'dayjs/locale/pl'
|
||||
import 'dayjs/locale/pt-br'
|
||||
import 'dayjs/locale/ro'
|
||||
import 'dayjs/locale/ru'
|
||||
import 'dayjs/locale/sl'
|
||||
import 'dayjs/locale/th'
|
||||
import 'dayjs/locale/tr'
|
||||
import 'dayjs/locale/uk'
|
||||
import 'dayjs/locale/vi'
|
||||
import 'dayjs/locale/zh-cn'
|
||||
import 'dayjs/locale/zh-tw'
|
||||
|
||||
const localeMap: Record<Locale, string> = {
|
||||
'en-US': 'en',
|
||||
'zh-Hans': 'zh-cn',
|
||||
'zh-Hant': 'zh-tw',
|
||||
'pt-BR': 'pt-br',
|
||||
'es-ES': 'es',
|
||||
'fr-FR': 'fr',
|
||||
'de-DE': 'de',
|
||||
'ja-JP': 'ja',
|
||||
'ko-KR': 'ko',
|
||||
'ru-RU': 'ru',
|
||||
'it-IT': 'it',
|
||||
'th-TH': 'th',
|
||||
'id-ID': 'id',
|
||||
'uk-UA': 'uk',
|
||||
'vi-VN': 'vi',
|
||||
'ro-RO': 'ro',
|
||||
'pl-PL': 'pl',
|
||||
'hi-IN': 'hi',
|
||||
'tr-TR': 'tr',
|
||||
'fa-IR': 'fa',
|
||||
'sl-SI': 'sl',
|
||||
}
|
||||
|
||||
/**
|
||||
* Formats a number with comma separators.
|
||||
* @example formatNumber(1234567) will return '1,234,567'
|
||||
|
|
@ -90,3 +137,7 @@ export const formatNumberAbbreviated = (num: number) => {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const formatToLocalTime = (time: Dayjs, local: string, format: string) => {
|
||||
return time.locale(localeMap[local] ?? 'en').format(format)
|
||||
}
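A quick usage sketch of the new `formatToLocalTime` helper, using the `@/utils/format` import path seen in the components above; locale keys are the `Locale` values from `localeMap`, with anything unmapped falling back to English.

```ts
import dayjs from 'dayjs'
import { formatToLocalTime } from '@/utils/format'

// 'zh-Hans' maps to dayjs's 'zh-cn' locale, so the month name is localized (e.g. '1月 15').
const zhLabel = formatToLocalTime(dayjs('2024-01-15'), 'zh-Hans', 'MMM D')

// 'en-US' maps to 'en', producing 'Jan 15'.
const enLabel = formatToLocalTime(dayjs('2024-01-15'), 'en-US', 'MMM D')
```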
|
||||
|
|
|
|||
|
|
@ -45,5 +45,118 @@ describe('get-icon', () => {
|
|||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toBe(`${MARKETPLACE_API_PREFIX}/plugins/${pluginId}/icon`)
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: Path traversal attempts
|
||||
* These tests document current behavior and potential security concerns
|
||||
* Note: Current implementation does not sanitize path traversal sequences
|
||||
*/
|
||||
test('handles path traversal attempts', () => {
|
||||
const pluginId = '../../../etc/passwd'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
// Current implementation includes path traversal sequences in URL
|
||||
// This is a potential security concern that should be addressed
|
||||
expect(result).toContain('../')
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('handles multiple path traversal attempts', () => {
|
||||
const pluginId = '../../../../etc/passwd'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
// Current implementation includes path traversal sequences in URL
|
||||
expect(result).toContain('../')
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('passes through URL-encoded path traversal sequences', () => {
|
||||
const pluginId = '..%2F..%2Fetc%2Fpasswd'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: Null and undefined handling
|
||||
* These tests document current behavior with invalid input types
|
||||
* Note: Current implementation converts null/undefined to strings instead of throwing
|
||||
*/
|
||||
test('handles null plugin ID', () => {
|
||||
// Current implementation converts null to string "null"
|
||||
const result = getIconFromMarketPlace(null as any)
|
||||
expect(result).toContain('null')
|
||||
// This is a potential issue - should validate input type
|
||||
})
|
||||
|
||||
test('handles undefined plugin ID', () => {
|
||||
// Current implementation converts undefined to string "undefined"
|
||||
const result = getIconFromMarketPlace(undefined as any)
|
||||
expect(result).toContain('undefined')
|
||||
// This is a potential issue - should validate input type
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: URL-sensitive characters
|
||||
* These tests verify that URL-sensitive characters are handled appropriately
|
||||
*/
|
||||
test('does not encode URL-sensitive characters', () => {
|
||||
const pluginId = 'plugin/with?special=chars#hash'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
// Note: Current implementation doesn't encode, but test documents the behavior
|
||||
expect(result).toContain(pluginId)
|
||||
expect(result).toContain('?')
|
||||
expect(result).toContain('#')
|
||||
expect(result).toContain('=')
|
||||
})
|
||||
|
||||
test('handles URL characters like & and %', () => {
|
||||
const pluginId = 'plugin&with%encoding'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
/**
|
||||
* Edge case tests: Extreme inputs
|
||||
* These tests verify behavior with unusual but valid inputs
|
||||
*/
|
||||
test('handles very long plugin ID', () => {
|
||||
const pluginId = 'a'.repeat(10000)
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
expect(result.length).toBeGreaterThan(10000)
|
||||
})
|
||||
|
||||
test('handles Unicode characters', () => {
|
||||
const pluginId = '插件-🚀-测试-日本語'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('handles control characters', () => {
|
||||
const pluginId = 'plugin\nwith\ttabs\r\nand\0null'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
/**
|
||||
* Security tests: XSS attempts
|
||||
* These tests verify that XSS attempts are handled appropriately
|
||||
*/
|
||||
test('handles XSS attempts with script tags', () => {
|
||||
const pluginId = '<script>alert("xss")</script>'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
// Note: Current implementation doesn't sanitize, but test documents the behavior
|
||||
})
|
||||
|
||||
test('handles XSS attempts with event handlers', () => {
|
||||
const pluginId = 'plugin"onerror="alert(1)"'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
|
||||
test('handles XSS attempts with encoded script tags', () => {
|
||||
const pluginId = '%3Cscript%3Ealert%28%22xss%22%29%3C%2Fscript%3E'
|
||||
const result = getIconFromMarketPlace(pluginId)
|
||||
expect(result).toContain(pluginId)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -87,7 +87,8 @@ describe('time', () => {
|
|||
test('works with timestamps', () => {
|
||||
const date = 1705276800000 // 2024-01-15 00:00:00 UTC
|
||||
const result = formatTime({ date, dateFormat: 'YYYY-MM-DD' })
|
||||
expect(result).toContain('2024-01-1') // Account for timezone differences
|
||||
// Account for timezone differences: UTC-5 to UTC+8 can result in 2024-01-14 or 2024-01-15
|
||||
expect(result).toMatch(/^2024-01-(14|15)$/)
|
||||
})
|
||||
|
||||
test('handles ISO 8601 format', () => {
|
||||
|
|
|
|||