From 1c1f124891c8e07f4c18fd1f9239226547bf32ae Mon Sep 17 00:00:00 2001
From: QuantumGhost
Date: Wed, 26 Nov 2025 19:59:34 +0800
Subject: [PATCH 001/431] Enhanced GraphEngine Pause Handling (#28196)
This commit:
1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured.
2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: -LAN-
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
api/.importlinter | 1 +
.../app/layers/pause_state_persist_layer.py | 1 +
api/core/workflow/entities/__init__.py | 6 --
api/core/workflow/entities/pause_reason.py | 47 ++++--------
.../graph_engine/domain/graph_execution.py | 12 ++--
.../event_management/event_manager.py | 8 ++-
.../workflow/graph_engine/graph_engine.py | 8 +--
api/core/workflow/graph_events/graph.py | 3 +-
.../nodes/human_input/human_input_node.py | 3 +-
.../workflow/runtime/graph_runtime_state.py | 8 ++-
...b7a422_add_workflow_pause_reasons_table.py | 41 +++++++++++
api/models/workflow.py | 66 +++++++++++++++++
.../api_workflow_run_repository.py | 4 +-
.../entities/workflow_pause.py | 15 ++++
.../sqlalchemy_api_workflow_run_repository.py | 71 +++++++++++++------
api/services/workflow_service.py | 3 +-
.../layers/test_pause_state_persist_layer.py | 13 ++--
.../test_workflow_pause_integration.py | 25 +++++--
.../layers/test_pause_state_persist_layer.py | 16 +++--
.../entities/test_private_workflow_pause.py | 52 +++-----------
.../workflow/graph/test_graph_validation.py | 3 +-
.../graph_engine/test_command_system.py | 5 +-
..._sqlalchemy_api_workflow_run_repository.py | 21 +++---
.../test_workflow_run_service_pause.py | 28 +-------
24 files changed, 275 insertions(+), 185 deletions(-)
create mode 100644 api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py
rename api/{core/workflow => repositories}/entities/workflow_pause.py (77%)
diff --git a/api/.importlinter b/api/.importlinter
index 98fe5f50bb..24ece72b30 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -16,6 +16,7 @@ layers =
graph
nodes
node_events
+ runtime
entities
containers =
core.workflow
diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py
index 412eb98dd4..61a3e1baca 100644
--- a/api/core/app/layers/pause_state_persist_layer.py
+++ b/api/core/app/layers/pause_state_persist_layer.py
@@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=state.dumps(),
+ pause_reasons=event.reasons,
)
def on_graph_end(self, error: Exception | None) -> None:
diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py
index f4ce9052e0..be70e467a0 100644
--- a/api/core/workflow/entities/__init__.py
+++ b/api/core/workflow/entities/__init__.py
@@ -1,17 +1,11 @@
-from ..runtime.graph_runtime_state import GraphRuntimeState
-from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
-from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
- "GraphRuntimeState",
- "VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
- "WorkflowPauseEntity",
]
diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py
index 16ad3d639d..c6655b7eab 100644
--- a/api/core/workflow/entities/pause_reason.py
+++ b/api/core/workflow/entities/pause_reason.py
@@ -1,49 +1,26 @@
from enum import StrEnum, auto
-from typing import Annotated, Any, ClassVar, TypeAlias
+from typing import Annotated, Literal, TypeAlias
-from pydantic import BaseModel, Discriminator, Tag
+from pydantic import BaseModel, Field
-class _PauseReasonType(StrEnum):
+class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
-class _PauseReasonBase(BaseModel):
- TYPE: ClassVar[_PauseReasonType]
+class HumanInputRequired(BaseModel):
+ TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
+
+ form_id: str
+ # The identifier of the human input node causing the pause.
+ node_id: str
-class HumanInputRequired(_PauseReasonBase):
- TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
-
-
-class SchedulingPause(_PauseReasonBase):
- TYPE = _PauseReasonType.SCHEDULED_PAUSE
+class SchedulingPause(BaseModel):
+ TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
message: str
-def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
- if isinstance(v, _PauseReasonBase):
- return v.TYPE
- elif isinstance(v, dict):
- reason_type_str = v.get("TYPE")
- if reason_type_str is None:
- return None
- try:
- reason_type = _PauseReasonType(reason_type_str)
- except ValueError:
- return None
- return reason_type
- else:
- # return None if the discriminator value isn't found
- return None
-
-
-PauseReason: TypeAlias = Annotated[
- (
- Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
- | Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
- ),
- Discriminator(_get_pause_reason_discriminator),
-]
+PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py
index 3d587d6691..9ca607458f 100644
--- a/api/core/workflow/graph_engine/domain/graph_execution.py
+++ b/api/core/workflow/graph_engine/domain/graph_execution.py
@@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
- pause_reason: PauseReason | None = Field(default=None)
+ pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@@ -107,7 +107,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
paused: bool = False
- pause_reason: PauseReason | None = None
+ pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@@ -137,10 +137,8 @@ class GraphExecution:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
- if self.paused:
- return
self.paused = True
- self.pause_reason = reason
+ self.pause_reasons.append(reason)
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
@@ -195,7 +193,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
- pause_reason=self.pause_reason,
+ pause_reasons=self.pause_reasons,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@@ -221,7 +219,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
- self.pause_reason = state.pause_reason
+ self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {
diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/core/workflow/graph_engine/event_management/event_manager.py
index 689cf53cf0..71043b9a43 100644
--- a/api/core/workflow/graph_engine/event_management/event_manager.py
+++ b/api/core/workflow/graph_engine/event_management/event_manager.py
@@ -110,7 +110,13 @@ class EventManager:
"""
with self._lock.write_lock():
self._events.append(event)
- self._notify_layers(event)
+
+ # NOTE: `_notify_layers` is intentionally called outside the critical section
+ # to minimize lock contention and avoid blocking other readers or writers.
+ #
+ # The public `notify_layers` method also does not use a write lock,
+ # so protecting `_notify_layers` with a lock here is unnecessary.
+ self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 98e1a20044..a4b2df2a8c 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -232,7 +232,7 @@ class GraphEngine:
self._graph_execution.start()
else:
self._graph_execution.paused = False
- self._graph_execution.pause_reason = None
+ self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
@@ -246,11 +246,11 @@ class GraphEngine:
# Handle completion
if self._graph_execution.is_paused:
- pause_reason = self._graph_execution.pause_reason
- assert pause_reason is not None, "pause_reason should not be None when execution is paused."
+ pause_reasons = self._graph_execution.pause_reasons
+ assert pause_reasons, "pause_reasons should not be empty when execution is paused."
# Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent(
- reason=pause_reason,
+ reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)
diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py
index 9faafc3173..5d10a76c15 100644
--- a/api/core/workflow/graph_events/graph.py
+++ b/api/core/workflow/graph_events/graph.py
@@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
- # reason: str | None = Field(default=None, description="reason for pause")
- reason: PauseReason = Field(..., description="reason for pause")
+ reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",
diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py
index 2d6d9760af..c0d64a060a 100644
--- a/api/core/workflow/nodes/human_input/human_input_node.py
+++ b/api/core/workflow/nodes/human_input/human_input_node.py
@@ -65,7 +65,8 @@ class HumanInputNode(Node):
return self._pause_generator()
def _pause_generator(self):
- yield PauseRequestedEvent(reason=HumanInputRequired())
+ # TODO(QuantumGhost): yield a real form id.
+ yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py
index 0fbc8ab23e..1561b789df 100644
--- a/api/core/workflow/runtime/graph_runtime_state.py
+++ b/api/core/workflow/runtime/graph_runtime_state.py
@@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
@@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
class GraphExecutionProtocol(Protocol):
- """Structural interface for graph execution aggregate."""
+ """Structural interface for graph execution aggregate.
+
+ Defines the minimal set of attributes and methods required from a GraphExecution entity
+ for runtime orchestration and state management.
+ """
workflow_id: str
started: bool
@@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
+ pause_reasons: list[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""
diff --git a/api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py b/api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py
new file mode 100644
index 0000000000..8478820999
--- /dev/null
+++ b/api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py
@@ -0,0 +1,41 @@
+"""Add workflow_pauses_reasons table
+
+Revision ID: 7bb281b7a422
+Revises: 09cfdda155d1
+Create Date: 2025-11-18 18:59:26.999572
+
+"""
+
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "7bb281b7a422"
+down_revision = "09cfdda155d1"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.create_table(
+ "workflow_pause_reasons",
+ sa.Column("id", models.types.StringUUID(), nullable=False),
+ sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
+ sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
+
+ sa.Column("pause_id", models.types.StringUUID(), nullable=False),
+ sa.Column("type_", sa.String(20), nullable=False),
+ sa.Column("form_id", sa.String(length=36), nullable=False),
+ sa.Column("node_id", sa.String(length=255), nullable=False),
+ sa.Column("message", sa.String(length=255), nullable=False),
+
+ sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")),
+ )
+ with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op:
+ batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False)
+
+
+def downgrade():
+ op.drop_table("workflow_pause_reasons")
diff --git a/api/models/workflow.py b/api/models/workflow.py
index f206a6a870..4efa829692 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -29,6 +29,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
+from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
@@ -1728,3 +1729,68 @@ class WorkflowPause(DefaultFieldsMixin, Base):
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
back_populates="pause",
)
+
+
+class WorkflowPauseReason(DefaultFieldsMixin, Base):
+ __tablename__ = "workflow_pause_reasons"
+
+ # `pause_id` represents the identifier of the pause,
+ # correspond to the `id` field of `WorkflowPause`.
+ pause_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
+
+ type_: Mapped[PauseReasonType] = mapped_column(EnumText(PauseReasonType), nullable=False)
+
+ # form_id is not empty if and if only type_ == PauseReasonType.HUMAN_INPUT_REQUIRED
+ #
+ form_id: Mapped[str] = mapped_column(
+ String(36),
+ nullable=False,
+ default="",
+ )
+
+ # message records the text description of this pause reason. For example,
+ # "The workflow has been paused due to scheduling."
+ #
+ # Empty message means that this pause reason is not speified.
+ message: Mapped[str] = mapped_column(
+ String(255),
+ nullable=False,
+ default="",
+ )
+
+ # `node_id` is the identifier of node causing the pasue, correspond to
+ # `Node.id`. Empty `node_id` means that this pause reason is not caused by any specific node
+ # (E.G. time slicing pauses.)
+ node_id: Mapped[str] = mapped_column(
+ String(255),
+ nullable=False,
+ default="",
+ )
+
+ # Relationship to WorkflowPause
+ pause: Mapped[WorkflowPause] = orm.relationship(
+ foreign_keys=[pause_id],
+ # require explicit preloading.
+ lazy="raise",
+ uselist=False,
+ primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
+ )
+
+ @classmethod
+ def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
+ if isinstance(pause_reason, HumanInputRequired):
+ return cls(
+ type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
+ )
+ elif isinstance(pause_reason, SchedulingPause):
+ return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
+ else:
+ raise AssertionError(f"Unknown pause reason type: {pause_reason}")
+
+ def to_entity(self) -> PauseReason:
+ if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
+ return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
+ elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
+ return SchedulingPause(message=self.message)
+ else:
+ raise AssertionError(f"Unknown pause reason type: {self.type_}")
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index 21fd57cd22..fd547c78ba 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -38,11 +38,12 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
-from core.workflow.entities.workflow_pause import WorkflowPauseEntity
+from core.workflow.entities.pause_reason import PauseReason
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
+from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
@@ -257,6 +258,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
workflow_run_id: str,
state_owner_user_id: str,
state: str,
+ pause_reasons: Sequence[PauseReason],
) -> WorkflowPauseEntity:
"""
Create a new workflow pause state.
diff --git a/api/core/workflow/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py
similarity index 77%
rename from api/core/workflow/entities/workflow_pause.py
rename to api/repositories/entities/workflow_pause.py
index 2f31c1ff53..b970f39816 100644
--- a/api/core/workflow/entities/workflow_pause.py
+++ b/api/repositories/entities/workflow_pause.py
@@ -7,8 +7,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
"""
from abc import ABC, abstractmethod
+from collections.abc import Sequence
from datetime import datetime
+from core.workflow.entities.pause_reason import PauseReason
+
class WorkflowPauseEntity(ABC):
"""
@@ -59,3 +62,15 @@ class WorkflowPauseEntity(ABC):
the pause is not resumed yet.
"""
pass
+
+ @abstractmethod
+ def get_pause_reasons(self) -> Sequence[PauseReason]:
+ """
+ Retrieve detailed reasons for this pause.
+
+ Returns a sequence of `PauseReason` objects describing the specific nodes and
+ reasons for which the workflow execution was paused.
+ This information is related to, but distinct from, the `PauseReason` type
+ defined in `api/core/workflow/entities/pause_reason.py`.
+ """
+ ...
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index eb2a32d764..b172c6a3ac 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -31,7 +31,7 @@ from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
-from core.workflow.entities.workflow_pause import WorkflowPauseEntity
+from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
@@ -41,8 +41,9 @@ from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowRun
+from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
@@ -318,6 +319,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
workflow_run_id: str,
state_owner_user_id: str,
state: str,
+ pause_reasons: Sequence[PauseReason],
) -> WorkflowPauseEntity:
"""
Create a new workflow pause state.
@@ -371,6 +373,25 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_model.workflow_run_id = workflow_run.id
pause_model.state_object_key = state_obj_key
pause_model.created_at = naive_utc_now()
+ pause_reason_models = []
+ for reason in pause_reasons:
+ if isinstance(reason, HumanInputRequired):
+ # TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
+ pause_reason_model = WorkflowPauseReason(
+ pause_id=pause_model.id,
+ type_=reason.TYPE,
+ form_id=reason.form_id,
+ )
+ elif isinstance(reason, SchedulingPause):
+ pause_reason_model = WorkflowPauseReason(
+ pause_id=pause_model.id,
+ type_=reason.TYPE,
+ message=reason.message,
+ )
+ else:
+ raise AssertionError(f"unkown reason type: {type(reason)}")
+
+ pause_reason_models.append(pause_reason_model)
# Update workflow run status
workflow_run.status = WorkflowExecutionStatus.PAUSED
@@ -378,10 +399,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Save everything in a transaction
session.add(pause_model)
session.add(workflow_run)
+ session.add_all(pause_reason_models)
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
- return _PrivateWorkflowPauseEntity.from_models(pause_model)
+ return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
+
+ def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
+ reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
+ pause_reason_models = session.scalars(reason_stmt).all()
+ return pause_reason_models
def get_workflow_pause(
self,
@@ -413,8 +440,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_model = workflow_run.pause
if pause_model is None:
return None
+ pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
- return _PrivateWorkflowPauseEntity.from_models(pause_model)
+ human_input_form: list[Any] = []
+ # TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
+
+ return _PrivateWorkflowPauseEntity(
+ pause_model=pause_model,
+ reason_models=pause_reason_models,
+ human_input_form=human_input_form,
+ )
def resume_workflow_pause(
self,
@@ -466,6 +501,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if pause_model.resumed_at is not None:
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
+ pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
+
# Mark as resumed
pause_model.resumed_at = naive_utc_now()
workflow_run.pause_id = None # type: ignore
@@ -476,7 +513,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
- return _PrivateWorkflowPauseEntity.from_models(pause_model)
+ return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
def delete_workflow_pause(
self,
@@ -815,26 +852,13 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
self,
*,
pause_model: WorkflowPauseModel,
+ reason_models: Sequence[WorkflowPauseReason],
+ human_input_form: Sequence = (),
) -> None:
self._pause_model = pause_model
+ self._reason_models = reason_models
self._cached_state: bytes | None = None
-
- @classmethod
- def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
- """
- Create a _PrivateWorkflowPauseEntity from database models.
-
- Args:
- workflow_pause_model: The WorkflowPause database model
- upload_file_model: The UploadFile database model
-
- Returns:
- _PrivateWorkflowPauseEntity: The constructed entity
-
- Raises:
- ValueError: If required model attributes are missing
- """
- return cls(pause_model=workflow_pause_model)
+ self._human_input_form = human_input_form
@property
def id(self) -> str:
@@ -867,3 +891,6 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
@property
def resumed_at(self) -> datetime | None:
return self._pause_model.resumed_at
+
+ def get_pause_reasons(self) -> Sequence[PauseReason]:
+ return [reason.to_entity() for reason in self._reason_models]
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index b6764f1fa7..b45a167b73 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -15,7 +15,7 @@ from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
-from core.workflow.entities import VariablePool, WorkflowNodeExecution
+from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan
diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
index bec3517d66..72469ad646 100644
--- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -319,7 +319,7 @@ class TestPauseStatePersistenceLayerTestContainers:
# Create pause event
event = GraphRunPausedEvent(
- reason=SchedulingPause(message="test pause"),
+ reasons=[SchedulingPause(message="test pause")],
outputs={"intermediate": "result"},
)
@@ -381,7 +381,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
- event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+ event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act - Save pause state
layer.on_event(event)
@@ -390,6 +390,7 @@ class TestPauseStatePersistenceLayerTestContainers:
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
+ assert pause_entity.get_pause_reasons() == event.reasons
state_bytes = pause_entity.get_state()
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
@@ -414,7 +415,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
- event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+ event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act
layer.on_event(event)
@@ -448,7 +449,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
- event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+ event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act
layer.on_event(event)
@@ -514,7 +515,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
- event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+ event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act
layer.on_event(event)
@@ -570,7 +571,7 @@ class TestPauseStatePersistenceLayerTestContainers:
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
- event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+ event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
index 79da5d4d0e..889e3d1d83 100644
--- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
+++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
@@ -334,12 +334,14 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
# Assert - Pause state created
assert pause_entity is not None
assert pause_entity.id is not None
assert pause_entity.workflow_execution_id == workflow_run.id
+ assert list(pause_entity.get_pause_reasons()) == []
# Convert both to strings for comparison
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
@@ -366,6 +368,7 @@ class TestWorkflowPauseIntegration:
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
+ assert list(retrieved_entity.get_pause_reasons()) == []
# Act - Resume workflow
resumed_entity = repository.resume_workflow_pause(
@@ -402,6 +405,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
assert pause_entity is not None
@@ -432,6 +436,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
@@ -449,6 +454,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
self.session.refresh(workflow_run)
@@ -480,6 +486,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
self.session.refresh(workflow_run)
@@ -503,6 +510,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.resumed_at = naive_utc_now()
@@ -530,6 +538,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=nonexistent_id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
def test_resume_nonexistent_workflow_run(self):
@@ -543,6 +552,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
nonexistent_id = str(uuid.uuid4())
@@ -570,6 +580,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
# Manually adjust timestamps for testing
@@ -648,6 +659,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
pause_entities.append(pause_entity)
@@ -750,6 +762,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run1.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
# Try to access pause from tenant 2 using tenant 1's repository
@@ -762,6 +775,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run2.id,
state_owner_user_id=account2.id,
state=test_state,
+ pause_reasons=[],
)
# Assert - Both pauses should exist and be separate
@@ -782,6 +796,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
# Verify pause is properly scoped
@@ -802,6 +817,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
+ pause_reasons=[],
)
# Assert - Verify file was uploaded to storage
@@ -828,9 +844,7 @@ class TestWorkflowPauseIntegration:
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
- workflow_run_id=workflow_run.id,
- state_owner_user_id=self.test_user_id,
- state=test_state,
+ workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, pause_reasons=[]
)
# Get file info before deletion
@@ -868,6 +882,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=large_state_json,
+ pause_reasons=[],
)
# Assert
@@ -902,9 +917,7 @@ class TestWorkflowPauseIntegration:
# Pause
pause_entity = repository.create_workflow_pause(
- workflow_run_id=workflow_run.id,
- state_owner_user_id=self.test_user_id,
- state=state,
+ workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=state, pause_reasons=[]
)
assert pause_entity is not None
diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
index 807f5e0fa5..534420f21e 100644
--- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -31,7 +31,7 @@ class TestDataFactory:
@staticmethod
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
- return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
+ return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {})
@staticmethod
def create_graph_run_started_event() -> GraphRunStartedEvent:
@@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer:
layer.on_event(event)
mock_factory.assert_called_once_with(session_factory)
- mock_repo.create_workflow_pause.assert_called_once_with(
- workflow_run_id="run-123",
- state_owner_user_id="owner-123",
- state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
- )
- serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
+ assert mock_repo.create_workflow_pause.call_count == 1
+ call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs
+ assert call_kwargs["workflow_run_id"] == "run-123"
+ assert call_kwargs["state_owner_user_id"] == "owner-123"
+ serialized_state = call_kwargs["state"]
resumption_context = WorkflowResumptionContext.loads(serialized_state)
assert resumption_context.serialized_graph_runtime_state == expected_state
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
+ pause_reasons = call_kwargs["pause_reasons"]
+
+ assert isinstance(pause_reasons, list)
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
diff --git a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py
index ccb2dff85a..be165bf1c1 100644
--- a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py
+++ b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py
@@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model.resumed_at = None
# Create entity
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# Verify initialization
assert entity._pause_model is mock_pause_model
assert entity._cached_state is None
- def test_from_models_classmethod(self):
- """Test from_models class method."""
- # Create mock models
- mock_pause_model = MagicMock(spec=WorkflowPauseModel)
- mock_pause_model.id = "pause-123"
- mock_pause_model.workflow_run_id = "execution-456"
-
- # Create entity using from_models
- entity = _PrivateWorkflowPauseEntity.from_models(
- workflow_pause_model=mock_pause_model,
- )
-
- # Verify entity creation
- assert isinstance(entity, _PrivateWorkflowPauseEntity)
- assert entity._pause_model is mock_pause_model
-
def test_id_property(self):
"""Test id property returns pause model ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.id == "pause-123"
@@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.workflow_run_id = "execution-456"
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.workflow_execution_id == "execution-456"
@@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = resumed_at
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.resumed_at == resumed_at
@@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = None
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.resumed_at is None
@@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# First call should load from storage
result = entity.get_state()
@@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# First call
result1 = entity.get_state()
@@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# Pre-cache data
entity._cached_state = state_data
@@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
- entity = _PrivateWorkflowPauseEntity(
- pause_model=mock_pause_model,
- )
+ entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
result = entity.get_state()
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
index c55c40c5b4..0f62a11684 100644
--- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
+++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
@@ -8,12 +8,13 @@ from typing import Any
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
+from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
index 868edf9832..5d958803bc 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
@@ -178,8 +178,7 @@ def test_pause_command():
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
assert len(pause_events) == 1
- assert pause_events[0].reason == SchedulingPause(message="User requested pause")
+ assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")]
graph_execution = engine.graph_runtime_state.graph_execution
- assert graph_execution.paused
- assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
+ assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
index 73b35b8e63..0c34676252 100644
--- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
+++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
@@ -6,10 +6,10 @@ from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session, sessionmaker
-from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
+from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_PrivateWorkflowPauseEntity,
@@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id=workflow_run_id,
state_owner_user_id=state_owner_user_id,
state=state,
+ pause_reasons=[],
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
assert result.workflow_execution_id == workflow_run_id
+ assert result.get_pause_reasons() == []
# Verify database interactions
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
@@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
+ pause_reasons=[],
)
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
@@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
+ pause_reasons=[],
)
@@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test _PrivateWorkflowPauseEntity class."""
- def test_from_models(self, sample_workflow_pause: Mock):
- """Test creating _PrivateWorkflowPauseEntity from models."""
- # Act
- entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
-
- # Assert
- assert isinstance(entity, _PrivateWorkflowPauseEntity)
- assert entity._pause_model == sample_workflow_pause
-
def test_properties(self, sample_workflow_pause: Mock):
"""Test entity properties."""
# Arrange
- entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+ entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Act & Assert
assert entity.id == sample_workflow_pause.id
@@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
def test_get_state(self, sample_workflow_pause: Mock):
"""Test getting state from storage."""
# Arrange
- entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+ entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
@@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
def test_get_state_caching(self, sample_workflow_pause: Mock):
"""Test state caching in get_state method."""
# Arrange
- entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+ entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py
index a062d9444e..f45a72927e 100644
--- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py
+++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py
@@ -17,6 +17,7 @@ from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowExecutionStatus
+from models.workflow import WorkflowPause
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import (
@@ -63,7 +64,7 @@ class TestDataFactory:
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowPauseModel object."""
- mock_pause = MagicMock()
+ mock_pause = MagicMock(spec=WorkflowPause)
mock_pause.id = id
mock_pause.tenant_id = tenant_id
mock_pause.app_id = app_id
@@ -77,38 +78,15 @@ class TestDataFactory:
return mock_pause
- @staticmethod
- def create_upload_file_mock(
- id: str = "file-456",
- key: str = "upload_files/test/state.json",
- name: str = "state.json",
- tenant_id: str = "tenant-456",
- **kwargs,
- ) -> MagicMock:
- """Create a mock UploadFile object."""
- mock_file = MagicMock()
- mock_file.id = id
- mock_file.key = key
- mock_file.name = name
- mock_file.tenant_id = tenant_id
-
- for key, value in kwargs.items():
- setattr(mock_file, key, value)
-
- return mock_file
-
@staticmethod
def create_pause_entity_mock(
pause_model: MagicMock | None = None,
- upload_file: MagicMock | None = None,
) -> _PrivateWorkflowPauseEntity:
"""Create a mock _PrivateWorkflowPauseEntity object."""
if pause_model is None:
pause_model = TestDataFactory.create_workflow_pause_mock()
- if upload_file is None:
- upload_file = TestDataFactory.create_upload_file_mock()
- return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
+ return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=[], human_input_form=[])
class TestWorkflowRunService:
From af587f38695800586965a2af21eea991464e6cc6 Mon Sep 17 00:00:00 2001
From: GuanMu
Date: Wed, 26 Nov 2025 22:37:05 +0800
Subject: [PATCH 002/431] chore: update packageManager version to pnpm@10.23.0
(#28708)
---
web/package.json | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/package.json b/web/package.json
index c7d8980f48..89a3a349a8 100644
--- a/web/package.json
+++ b/web/package.json
@@ -2,7 +2,7 @@
"name": "dify-web",
"version": "1.10.1",
"private": true,
- "packageManager": "pnpm@10.22.0+sha512.bf049efe995b28f527fd2b41ae0474ce29186f7edcb3bf545087bd61fbbebb2bf75362d1307fda09c2d288e1e499787ac12d4fcb617a974718a6051f2eee741c",
+ "packageManager": "pnpm@10.23.0+sha512.21c4e5698002ade97e4efe8b8b4a89a8de3c85a37919f957e7a0f30f38fbc5bbdd05980ffe29179b2fb6e6e691242e098d945d1601772cad0fef5fb6411e2a4b",
"engines": {
"node": ">=v22.11.0"
},
From 6b8c6498769a3716fc34b83384bdf9d3cb2d944c Mon Sep 17 00:00:00 2001
From: Yuichiro Utsumi <81412151+utsumi-fj@users.noreply.github.com>
Date: Wed, 26 Nov 2025 23:39:29 +0900
Subject: [PATCH 003/431] fix: prevent auto-scrolling from stopping in chat
(#28690)
Signed-off-by: Yuichiro Utsumi
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
web/app/components/base/chat/chat/index.tsx | 40 +++++++++++++++------
1 file changed, 30 insertions(+), 10 deletions(-)
diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx
index a362f4dc99..51b5df4f32 100644
--- a/web/app/components/base/chat/chat/index.tsx
+++ b/web/app/components/base/chat/chat/index.tsx
@@ -128,10 +128,17 @@ const Chat: FC = ({
const chatFooterRef = useRef(null)
const chatFooterInnerRef = useRef(null)
const userScrolledRef = useRef(false)
+ const isAutoScrollingRef = useRef(false)
const handleScrollToBottom = useCallback(() => {
- if (chatList.length > 1 && chatContainerRef.current && !userScrolledRef.current)
+ if (chatList.length > 1 && chatContainerRef.current && !userScrolledRef.current) {
+ isAutoScrollingRef.current = true
chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight
+
+ requestAnimationFrame(() => {
+ isAutoScrollingRef.current = false
+ })
+ }
}, [chatList.length])
const handleWindowResize = useCallback(() => {
@@ -198,18 +205,31 @@ const Chat: FC = ({
}, [handleScrollToBottom])
useEffect(() => {
- const chatContainer = chatContainerRef.current
- if (chatContainer) {
- const setUserScrolled = () => {
- // eslint-disable-next-line sonarjs/no-gratuitous-expressions
- if (chatContainer) // its in event callback, chatContainer may be null
- userScrolledRef.current = chatContainer.scrollHeight - chatContainer.scrollTop > chatContainer.clientHeight
- }
- chatContainer.addEventListener('scroll', setUserScrolled)
- return () => chatContainer.removeEventListener('scroll', setUserScrolled)
+ const setUserScrolled = () => {
+ const container = chatContainerRef.current
+ if (!container) return
+
+ if (isAutoScrollingRef.current) return
+
+ const distanceToBottom = container.scrollHeight - container.clientHeight - container.scrollTop
+ const SCROLL_UP_THRESHOLD = 100
+
+ userScrolledRef.current = distanceToBottom > SCROLL_UP_THRESHOLD
}
+
+ const container = chatContainerRef.current
+ if (!container) return
+
+ container.addEventListener('scroll', setUserScrolled)
+ return () => container.removeEventListener('scroll', setUserScrolled)
}, [])
+ // Reset user scroll state when a new chat starts (length <= 1)
+ useEffect(() => {
+ if (chatList.length <= 1)
+ userScrolledRef.current = false
+ }, [chatList.length])
+
useEffect(() => {
if (!sidebarCollapseState)
setTimeout(() => handleWindowResize(), 200)
From 6635ea62c2bfa7ff740056e22308ce8b0e0d4bf8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Wed, 26 Nov 2025 22:41:52 +0800
Subject: [PATCH 004/431] fix: change existing node to a webhook node raise 404
(#28686)
---
.../workflow/hooks/use-nodes-interactions.ts | 12 +++++++++++-
web/service/apps.ts | 6 +++++-
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts
index 3cbdf08e43..d56b85893e 100644
--- a/web/app/components/workflow/hooks/use-nodes-interactions.ts
+++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts
@@ -59,6 +59,7 @@ import {
useWorkflowHistory,
} from './use-workflow-history'
import { useNodesMetaData } from './use-nodes-meta-data'
+import { useAutoGenerateWebhookUrl } from './use-auto-generate-webhook-url'
import type { RAGPipelineVariables } from '@/models/pipeline'
import useInspectVarsCrud from './use-inspect-vars-crud'
import { getNodeUsedVars } from '../nodes/_base/components/variable/utils'
@@ -94,6 +95,7 @@ export const useNodesInteractions = () => {
const { nodesMap: nodesMetaDataMap } = useNodesMetaData()
const { saveStateToHistory, undo, redo } = useWorkflowHistory()
+ const autoGenerateWebhookUrl = useAutoGenerateWebhookUrl()
const handleNodeDragStart = useCallback(
(_, node) => {
@@ -1401,7 +1403,14 @@ export const useNodesInteractions = () => {
return filtered
})
setEdges(newEdges)
- handleSyncWorkflowDraft()
+ if (nodeType === BlockEnum.TriggerWebhook) {
+ handleSyncWorkflowDraft(true, true, {
+ onSuccess: () => autoGenerateWebhookUrl(newCurrentNode.id),
+ })
+ }
+ else {
+ handleSyncWorkflowDraft()
+ }
saveStateToHistory(WorkflowHistoryEvent.NodeChange, {
nodeId: currentNodeId,
@@ -1413,6 +1422,7 @@ export const useNodesInteractions = () => {
handleSyncWorkflowDraft,
saveStateToHistory,
nodesMetaDataMap,
+ autoGenerateWebhookUrl,
],
)
diff --git a/web/service/apps.ts b/web/service/apps.ts
index b1124767ad..7a4cfb93ff 100644
--- a/web/service/apps.ts
+++ b/web/service/apps.ts
@@ -164,7 +164,11 @@ export const updateTracingStatus: Fetcher = ({ appId, nodeId }) => {
- return get(`apps/${appId}/workflows/triggers/webhook`, { params: { node_id: nodeId } })
+ return get(
+ `apps/${appId}/workflows/triggers/webhook`,
+ { params: { node_id: nodeId } },
+ { silent: true },
+ )
}
export const fetchTracingConfig: Fetcher = ({ appId, provider }) => {
From e76129b5a4c4d82fd3efc27cb4eab14c8e9df0f4 Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 06:42:58 -0800
Subject: [PATCH 005/431] test: add comprehensive unit tests for
HitTestingService Fix: #28667 (#28668)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
api/tests/unit_tests/services/hit_service.py | 802 +++++++++++++++++++
1 file changed, 802 insertions(+)
create mode 100644 api/tests/unit_tests/services/hit_service.py
diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py
new file mode 100644
index 0000000000..17f3a7e94e
--- /dev/null
+++ b/api/tests/unit_tests/services/hit_service.py
@@ -0,0 +1,802 @@
+"""
+Unit tests for HitTestingService.
+
+This module contains comprehensive unit tests for the HitTestingService class,
+which handles retrieval testing operations for datasets, including internal
+dataset retrieval and external knowledge base retrieval.
+"""
+
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+
+from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from models import Account
+from models.dataset import Dataset
+from services.hit_testing_service import HitTestingService
+
+
+class HitTestingTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for hit testing service tests.
+
+ This factory provides static methods to create mock objects for datasets, users,
+ documents, and retrieval records used in HitTestingService unit tests.
+ """
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ provider: str = "vendor",
+ retrieval_model: dict | None = None,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock dataset with specified attributes.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ tenant_id: Tenant identifier
+ provider: Dataset provider (vendor, external, etc.)
+ retrieval_model: Optional retrieval model configuration
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ dataset.tenant_id = tenant_id
+ dataset.provider = provider
+ dataset.retrieval_model = retrieval_model
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+ @staticmethod
+ def create_user_mock(
+ user_id: str = "user-789",
+ tenant_id: str = "tenant-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock user (Account) with specified attributes.
+
+ Args:
+ user_id: Unique identifier for the user
+ tenant_id: Tenant identifier
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as an Account instance
+ """
+ user = Mock(spec=Account)
+ user.id = user_id
+ user.current_tenant_id = tenant_id
+ user.name = "Test User"
+ for key, value in kwargs.items():
+ setattr(user, key, value)
+ return user
+
+ @staticmethod
+ def create_document_mock(
+ content: str = "Test document content",
+ metadata: dict | None = None,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Document from core.rag.models.document.
+
+ Args:
+ content: Document content/text
+ metadata: Optional metadata dictionary
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Document instance
+ """
+ document = Mock(spec=Document)
+ document.page_content = content
+ document.metadata = metadata or {}
+ for key, value in kwargs.items():
+ setattr(document, key, value)
+ return document
+
+ @staticmethod
+ def create_retrieval_record_mock(
+ content: str = "Test content",
+ score: float = 0.95,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock retrieval record.
+
+ Args:
+ content: Record content
+ score: Retrieval score
+ **kwargs: Additional fields for the record
+
+ Returns:
+ Mock object with model_dump method returning record data
+ """
+ record = Mock()
+ record.model_dump.return_value = {
+ "content": content,
+ "score": score,
+ **kwargs,
+ }
+ return record
+
+
+class TestHitTestingServiceRetrieve:
+ """
+ Tests for HitTestingService.retrieve method (hit_testing).
+
+ This test class covers the main retrieval testing functionality, including
+ various retrieval model configurations, metadata filtering, and query logging.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session.
+
+ Provides a mocked database session for testing database operations
+ like adding and committing DatasetQuery records.
+ """
+ with patch("services.hit_testing_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
+ """
+ Test successful retrieval with default retrieval model.
+
+ Verifies that the retrieve method works correctly when no custom
+ retrieval model is provided, using the default retrieval configuration.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=None)
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ retrieval_model = None
+ external_retrieval_model = {}
+
+ documents = [
+ HitTestingTestDataFactory.create_document_mock(content="Doc 1"),
+ HitTestingTestDataFactory.create_document_mock(content="Doc 2"),
+ ]
+
+ mock_records = [
+ HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 1"),
+ HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2"),
+ ]
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
+ patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1] # start, end
+ mock_retrieve.return_value = documents
+ mock_format.return_value = mock_records
+
+ # Act
+ result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert len(result["records"]) == 2
+ mock_retrieve.assert_called_once()
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_retrieve_success_with_custom_retrieval_model(self, mock_db_session):
+ """
+ Test successful retrieval with custom retrieval model.
+
+ Verifies that custom retrieval model parameters (search method, reranking,
+ score threshold, etc.) are properly passed to RetrievalService.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock()
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ retrieval_model = {
+ "search_method": RetrievalMethod.KEYWORD_SEARCH,
+ "reranking_enable": True,
+ "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-1"},
+ "top_k": 5,
+ "score_threshold_enabled": True,
+ "score_threshold": 0.7,
+ "weights": {"vector_setting": 0.5, "keyword_setting": 0.5},
+ }
+ external_retrieval_model = {}
+
+ documents = [HitTestingTestDataFactory.create_document_mock()]
+ mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
+ patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1]
+ mock_retrieve.return_value = documents
+ mock_format.return_value = mock_records
+
+ # Act
+ result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
+
+ # Assert
+ assert result["query"]["content"] == query
+ mock_retrieve.assert_called_once()
+ call_kwargs = mock_retrieve.call_args[1]
+ assert call_kwargs["retrieval_method"] == RetrievalMethod.KEYWORD_SEARCH
+ assert call_kwargs["top_k"] == 5
+ assert call_kwargs["score_threshold"] == 0.7
+ assert call_kwargs["reranking_model"] == retrieval_model["reranking_model"]
+
+ def test_retrieve_with_metadata_filtering(self, mock_db_session):
+ """
+ Test retrieval with metadata filtering conditions.
+
+ Verifies that metadata filtering conditions are properly processed
+ and document ID filters are applied to the retrieval query.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock()
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ retrieval_model = {
+ "metadata_filtering_conditions": {
+ "conditions": [
+ {"field": "category", "operator": "is", "value": "test"},
+ ],
+ },
+ }
+ external_retrieval_model = {}
+
+ mock_dataset_retrieval = MagicMock()
+ mock_dataset_retrieval.get_metadata_filter_condition.return_value = (
+ {dataset.id: ["doc-1", "doc-2"]},
+ None,
+ )
+
+ documents = [HitTestingTestDataFactory.create_document_mock()]
+ mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
+ patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
+ patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1]
+ mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
+ mock_retrieve.return_value = documents
+ mock_format.return_value = mock_records
+
+ # Act
+ result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
+
+ # Assert
+ assert result["query"]["content"] == query
+ mock_dataset_retrieval.get_metadata_filter_condition.assert_called_once()
+ call_kwargs = mock_retrieve.call_args[1]
+ assert call_kwargs["document_ids_filter"] == ["doc-1", "doc-2"]
+
+ def test_retrieve_with_metadata_filtering_no_documents(self, mock_db_session):
+ """
+ Test retrieval with metadata filtering that returns no documents.
+
+ Verifies that when metadata filtering results in no matching documents,
+ an empty result is returned without calling RetrievalService.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock()
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ retrieval_model = {
+ "metadata_filtering_conditions": {
+ "conditions": [
+ {"field": "category", "operator": "is", "value": "test"},
+ ],
+ },
+ }
+ external_retrieval_model = {}
+
+ mock_dataset_retrieval = MagicMock()
+ mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)
+
+ with (
+ patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
+ patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
+ ):
+ mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
+ mock_format.return_value = []
+
+ # Act
+ result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert result["records"] == []
+
+ def test_retrieve_with_dataset_retrieval_model(self, mock_db_session):
+ """
+ Test retrieval using dataset's retrieval model when not provided.
+
+ Verifies that when no retrieval model is provided, the dataset's
+ retrieval model is used as a fallback.
+ """
+ # Arrange
+ dataset_retrieval_model = {
+ "search_method": RetrievalMethod.HYBRID_SEARCH,
+ "top_k": 3,
+ }
+ dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model)
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ retrieval_model = None
+ external_retrieval_model = {}
+
+ documents = [HitTestingTestDataFactory.create_document_mock()]
+ mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
+ patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1]
+ mock_retrieve.return_value = documents
+ mock_format.return_value = mock_records
+
+ # Act
+ result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
+
+ # Assert
+ assert result["query"]["content"] == query
+ call_kwargs = mock_retrieve.call_args[1]
+ assert call_kwargs["retrieval_method"] == RetrievalMethod.HYBRID_SEARCH
+ assert call_kwargs["top_k"] == 3
+
+
+class TestHitTestingServiceExternalRetrieve:
+ """
+ Tests for HitTestingService.external_retrieve method.
+
+ This test class covers external knowledge base retrieval functionality,
+ including query escaping, response formatting, and provider validation.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session.
+
+ Provides a mocked database session for testing database operations
+ like adding and committing DatasetQuery records.
+ """
+ with patch("services.hit_testing_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_external_retrieve_success(self, mock_db_session):
+ """
+ Test successful external retrieval.
+
+ Verifies that external knowledge base retrieval works correctly,
+ including query escaping, document formatting, and query logging.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = 'test query with "quotes"'
+ external_retrieval_model = {"top_k": 5, "score_threshold": 0.8}
+ metadata_filtering_conditions = {}
+
+ external_documents = [
+ {"content": "External doc 1", "title": "Title 1", "score": 0.95, "metadata": {"key": "value"}},
+ {"content": "External doc 2", "title": "Title 2", "score": 0.85, "metadata": {}},
+ ]
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1]
+ mock_external_retrieve.return_value = external_documents
+
+ # Act
+ result = HitTestingService.external_retrieve(
+ dataset, query, account, external_retrieval_model, metadata_filtering_conditions
+ )
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert len(result["records"]) == 2
+ assert result["records"][0]["content"] == "External doc 1"
+ assert result["records"][0]["title"] == "Title 1"
+ assert result["records"][0]["score"] == 0.95
+ mock_external_retrieve.assert_called_once()
+ # Verify query was escaped
+ assert mock_external_retrieve.call_args[1]["query"] == 'test query with \\"quotes\\"'
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_external_retrieve_non_external_provider(self, mock_db_session):
+ """
+ Test external retrieval with non-external provider (should return empty).
+
+ Verifies that when the dataset provider is not "external", the method
+ returns an empty result without performing retrieval or database operations.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="vendor")
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ external_retrieval_model = {}
+ metadata_filtering_conditions = {}
+
+ # Act
+ result = HitTestingService.external_retrieve(
+ dataset, query, account, external_retrieval_model, metadata_filtering_conditions
+ )
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert result["records"] == []
+ mock_db_session.add.assert_not_called()
+
+ def test_external_retrieve_with_metadata_filtering(self, mock_db_session):
+ """
+ Test external retrieval with metadata filtering conditions.
+
+ Verifies that metadata filtering conditions are properly passed
+ to the external retrieval service.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ external_retrieval_model = {"top_k": 3}
+ metadata_filtering_conditions = {"category": "test"}
+
+ external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1]
+ mock_external_retrieve.return_value = external_documents
+
+ # Act
+ result = HitTestingService.external_retrieve(
+ dataset, query, account, external_retrieval_model, metadata_filtering_conditions
+ )
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert len(result["records"]) == 1
+ call_kwargs = mock_external_retrieve.call_args[1]
+ assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions
+
+ def test_external_retrieve_empty_documents(self, mock_db_session):
+ """
+ Test external retrieval with empty document list.
+
+ Verifies that when external retrieval returns no documents,
+ an empty result is properly formatted and returned.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
+ account = HitTestingTestDataFactory.create_user_mock()
+ query = "test query"
+ external_retrieval_model = {}
+ metadata_filtering_conditions = {}
+
+ with (
+ patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
+ patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
+ ):
+ mock_perf_counter.side_effect = [0.0, 0.1]
+ mock_external_retrieve.return_value = []
+
+ # Act
+ result = HitTestingService.external_retrieve(
+ dataset, query, account, external_retrieval_model, metadata_filtering_conditions
+ )
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert result["records"] == []
+
+
+class TestHitTestingServiceCompactRetrieveResponse:
+ """
+ Tests for HitTestingService.compact_retrieve_response method.
+
+ This test class covers response formatting for internal dataset retrieval,
+ ensuring documents are properly formatted into retrieval records.
+ """
+
+ def test_compact_retrieve_response_success(self):
+ """
+ Test successful response formatting.
+
+ Verifies that documents are properly formatted into retrieval records
+ with correct structure and data.
+ """
+ # Arrange
+ query = "test query"
+ documents = [
+ HitTestingTestDataFactory.create_document_mock(content="Doc 1"),
+ HitTestingTestDataFactory.create_document_mock(content="Doc 2"),
+ ]
+
+ mock_records = [
+ HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 1", score=0.95),
+ HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
+ ]
+
+ with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
+ mock_format.return_value = mock_records
+
+ # Act
+ result = HitTestingService.compact_retrieve_response(query, documents)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert len(result["records"]) == 2
+ assert result["records"][0]["content"] == "Doc 1"
+ assert result["records"][0]["score"] == 0.95
+ mock_format.assert_called_once_with(documents)
+
+ def test_compact_retrieve_response_empty_documents(self):
+ """
+ Test response formatting with empty document list.
+
+ Verifies that an empty document list results in an empty records array
+ while maintaining the correct response structure.
+ """
+ # Arrange
+ query = "test query"
+ documents = []
+
+ with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
+ mock_format.return_value = []
+
+ # Act
+ result = HitTestingService.compact_retrieve_response(query, documents)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert result["records"] == []
+
+
+class TestHitTestingServiceCompactExternalRetrieveResponse:
+ """
+ Tests for HitTestingService.compact_external_retrieve_response method.
+
+ This test class covers response formatting for external knowledge base
+ retrieval, ensuring proper field extraction and provider validation.
+ """
+
+ def test_compact_external_retrieve_response_external_provider(self):
+ """
+ Test external response formatting for external provider.
+
+ Verifies that external documents are properly formatted with all
+ required fields (content, title, score, metadata).
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
+ query = "test query"
+ documents = [
+ {"content": "Doc 1", "title": "Title 1", "score": 0.95, "metadata": {"key": "value"}},
+ {"content": "Doc 2", "title": "Title 2", "score": 0.85, "metadata": {}},
+ ]
+
+ # Act
+ result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert len(result["records"]) == 2
+ assert result["records"][0]["content"] == "Doc 1"
+ assert result["records"][0]["title"] == "Title 1"
+ assert result["records"][0]["score"] == 0.95
+ assert result["records"][0]["metadata"] == {"key": "value"}
+
+ def test_compact_external_retrieve_response_non_external_provider(self):
+ """
+ Test external response formatting for non-external provider.
+
+ Verifies that non-external providers return an empty records array
+ regardless of input documents.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="vendor")
+ query = "test query"
+ documents = [{"content": "Doc 1"}]
+
+ # Act
+ result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert result["records"] == []
+
+ def test_compact_external_retrieve_response_missing_fields(self):
+ """
+ Test external response formatting with missing optional fields.
+
+ Verifies that missing optional fields (title, score, metadata) are
+ handled gracefully by setting them to None.
+ """
+ # Arrange
+ dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
+ query = "test query"
+ documents = [
+ {"content": "Doc 1"}, # Missing title, score, metadata
+ {"content": "Doc 2", "title": "Title 2"}, # Missing score, metadata
+ ]
+
+ # Act
+ result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)
+
+ # Assert
+ assert result["query"]["content"] == query
+ assert len(result["records"]) == 2
+ assert result["records"][0]["content"] == "Doc 1"
+ assert result["records"][0]["title"] is None
+ assert result["records"][0]["score"] is None
+ assert result["records"][0]["metadata"] is None
+
+
+class TestHitTestingServiceHitTestingArgsCheck:
+ """
+ Tests for HitTestingService.hit_testing_args_check method.
+
+ This test class covers query argument validation, ensuring queries
+ meet the required criteria (non-empty, max 250 characters).
+ """
+
+ def test_hit_testing_args_check_success(self):
+ """
+ Test successful argument validation.
+
+ Verifies that valid queries pass validation without raising errors.
+ """
+ # Arrange
+ args = {"query": "valid query"}
+
+ # Act & Assert (should not raise)
+ HitTestingService.hit_testing_args_check(args)
+
+ def test_hit_testing_args_check_empty_query(self):
+ """
+ Test validation fails with empty query.
+
+ Verifies that empty queries raise a ValueError with appropriate message.
+ """
+ # Arrange
+ args = {"query": ""}
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
+ HitTestingService.hit_testing_args_check(args)
+
+ def test_hit_testing_args_check_none_query(self):
+ """
+ Test validation fails with None query.
+
+ Verifies that None queries raise a ValueError with appropriate message.
+ """
+ # Arrange
+ args = {"query": None}
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
+ HitTestingService.hit_testing_args_check(args)
+
+ def test_hit_testing_args_check_too_long_query(self):
+ """
+ Test validation fails with query exceeding 250 characters.
+
+ Verifies that queries longer than 250 characters raise a ValueError.
+ """
+ # Arrange
+ args = {"query": "a" * 251}
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
+ HitTestingService.hit_testing_args_check(args)
+
+ def test_hit_testing_args_check_exactly_250_characters(self):
+ """
+ Test validation succeeds with exactly 250 characters.
+
+ Verifies that queries with exactly 250 characters (the maximum)
+ pass validation successfully.
+ """
+ # Arrange
+ args = {"query": "a" * 250}
+
+ # Act & Assert (should not raise)
+ HitTestingService.hit_testing_args_check(args)
+
+
+class TestHitTestingServiceEscapeQueryForSearch:
+ """
+ Tests for HitTestingService.escape_query_for_search method.
+
+ This test class covers query escaping functionality for external search,
+ ensuring special characters are properly escaped.
+ """
+
+ def test_escape_query_for_search_with_quotes(self):
+ """
+ Test escaping quotes in query.
+
+ Verifies that double quotes in queries are properly escaped with
+ backslashes for external search compatibility.
+ """
+ # Arrange
+ query = 'test query with "quotes"'
+
+ # Act
+ result = HitTestingService.escape_query_for_search(query)
+
+ # Assert
+ assert result == 'test query with \\"quotes\\"'
+
+ def test_escape_query_for_search_without_quotes(self):
+ """
+ Test query without quotes (no change).
+
+ Verifies that queries without quotes remain unchanged after escaping.
+ """
+ # Arrange
+ query = "test query without quotes"
+
+ # Act
+ result = HitTestingService.escape_query_for_search(query)
+
+ # Assert
+ assert result == query
+
+ def test_escape_query_for_search_multiple_quotes(self):
+ """
+ Test escaping multiple quotes in query.
+
+ Verifies that all occurrences of double quotes in a query are
+ properly escaped, not just the first one.
+ """
+ # Arrange
+ query = 'test "query" with "multiple" quotes'
+
+ # Act
+ result = HitTestingService.escape_query_for_search(query)
+
+ # Assert
+ assert result == 'test \\"query\\" with \\"multiple\\" quotes'
+
+ def test_escape_query_for_search_empty_string(self):
+ """
+ Test escaping empty string.
+
+ Verifies that empty strings are handled correctly and remain empty
+ after the escaping operation.
+ """
+ # Arrange
+ query = ""
+
+ # Act
+ result = HitTestingService.escape_query_for_search(query)
+
+ # Assert
+ assert result == ""
From e8ca80a61ad2ca2b8d19f94156d4ef9a4deb4a4c Mon Sep 17 00:00:00 2001
From: Satoshi Dev <162055292+0xsatoshi99@users.noreply.github.com>
Date: Wed, 26 Nov 2025 06:43:30 -0800
Subject: [PATCH 006/431] add unit tests for list operator node (#28597)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../workflow/nodes/list_operator/__init__.py | 1 +
.../workflow/nodes/list_operator/node_spec.py | 544 ++++++++++++++++++
2 files changed, 545 insertions(+)
create mode 100644 api/tests/unit_tests/core/workflow/nodes/list_operator/__init__.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/__init__.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/__init__.py
@@ -0,0 +1 @@
+
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
new file mode 100644
index 0000000000..366bec5001
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -0,0 +1,544 @@
+from unittest.mock import MagicMock
+
+import pytest
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+
+from core.variables import ArrayNumberSegment, ArrayStringSegment
+from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
+from core.workflow.nodes.list_operator.node import ListOperatorNode
+from models.workflow import WorkflowType
+
+
+class TestListOperatorNode:
+ """Comprehensive tests for ListOperatorNode."""
+
+ @pytest.fixture
+ def mock_graph_runtime_state(self):
+ """Create mock GraphRuntimeState."""
+ mock_state = MagicMock(spec=GraphRuntimeState)
+ mock_variable_pool = MagicMock()
+ mock_state.variable_pool = mock_variable_pool
+ return mock_state
+
+ @pytest.fixture
+ def mock_graph(self):
+ """Create mock Graph."""
+ return MagicMock(spec=Graph)
+
+ @pytest.fixture
+ def graph_init_params(self):
+ """Create GraphInitParams fixture."""
+ return GraphInitParams(
+ tenant_id="test",
+ app_id="test",
+ workflow_type=WorkflowType.WORKFLOW,
+ workflow_id="test",
+ graph_config={},
+ user_id="test",
+ user_from="test",
+ invoke_from="test",
+ call_depth=0,
+ )
+
+ @pytest.fixture
+ def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state):
+ """Factory fixture for creating ListOperatorNode instances."""
+
+ def _create_node(config, mock_variable):
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
+ return ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ return _create_node
+
+ def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test node initializes correctly."""
+ config = {
+ "title": "List Operator",
+ "variable": ["sys", "list"],
+ "filter_by": {"enabled": False},
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ assert node.node_type == NodeType.LIST_OPERATOR
+ assert node._node_data.title == "List Operator"
+
+ def test_version(self):
+ """Test version returns correct value."""
+ assert ListOperatorNode.version() == "1"
+
+ def test_run_with_string_array(self, list_operator_node_factory):
+ """Test with string array."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {"enabled": False},
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["apple", "banana", "cherry"])
+ node = list_operator_node_factory(config, mock_var)
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["apple", "banana", "cherry"]
+
+ def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test with empty array."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {"enabled": False},
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=[])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == []
+ assert result.outputs["first_record"] is None
+ assert result.outputs["last_record"] is None
+
+ def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test filter with contains condition."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {
+ "enabled": True,
+ "condition": "contains",
+ "value": "app",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["apple", "pineapple"]
+
+ def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test filter with not contains condition."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {
+ "enabled": True,
+ "condition": "not contains",
+ "value": "app",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["banana", "cherry"]
+
+ def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test filter with greater than condition on numbers."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "numbers"],
+ "filter_by": {
+ "enabled": True,
+ "condition": ">",
+ "value": "5",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == [7, 9, 11]
+
+ def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test ordering in ascending order."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {"enabled": False},
+ "order_by": {
+ "enabled": True,
+ "value": "asc",
+ },
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["apple", "banana", "cherry"]
+
+ def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test ordering in descending order."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {"enabled": False},
+ "order_by": {
+ "enabled": True,
+ "value": "desc",
+ },
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["cherry", "banana", "apple"]
+
+ def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test with limit enabled."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {"enabled": False},
+ "order_by": {"enabled": False},
+ "limit": {
+ "enabled": True,
+ "size": 2,
+ },
+ }
+
+ mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["apple", "banana"]
+
+ def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test with filter, order, and limit combined."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "numbers"],
+ "filter_by": {
+ "enabled": True,
+ "condition": ">",
+ "value": "3",
+ },
+ "order_by": {
+ "enabled": True,
+ "value": "desc",
+ },
+ "limit": {
+ "enabled": True,
+ "size": 3,
+ },
+ }
+
+ mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == [9, 8, 7]
+
+ def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test when variable is not found."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "missing"],
+ "filter_by": {"enabled": False},
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_graph_runtime_state.variable_pool.get.return_value = None
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert "Variable not found" in result.error
+
+ def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test first_record and last_record outputs."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {"enabled": False},
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["first", "middle", "last"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["first_record"] == "first"
+ assert result.outputs["last_record"] == "last"
+
+ def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test filter with startswith condition."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {
+ "enabled": True,
+ "condition": "start with",
+ "value": "app",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["apple", "application"]
+
+ def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test filter with endswith condition."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "items"],
+ "filter_by": {
+ "enabled": True,
+ "condition": "end with",
+ "value": "le",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == ["apple", "pineapple", "table"]
+
+ def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test number filter with equals condition."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "numbers"],
+ "filter_by": {
+ "enabled": True,
+ "condition": "=",
+ "value": "5",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == [5, 5]
+
+ def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test number filter with not equals condition."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "numbers"],
+ "filter_by": {
+ "enabled": True,
+ "condition": "≠",
+ "value": "5",
+ },
+ "order_by": {"enabled": False},
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == [1, 3, 7, 9]
+
+ def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test number ordering in ascending order."""
+ config = {
+ "title": "Test",
+ "variable": ["sys", "numbers"],
+ "filter_by": {"enabled": False},
+ "order_by": {
+ "enabled": True,
+ "value": "asc",
+ },
+ "limit": {"enabled": False},
+ }
+
+ mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_var
+
+ node = ListOperatorNode(
+ id="test",
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["result"].value == [1, 3, 5, 7, 9]
From 2731b04ff9d25b7d6049ea578d64746be578ab49 Mon Sep 17 00:00:00 2001
From: Asuka Minato
Date: Wed, 26 Nov 2025 23:44:14 +0900
Subject: [PATCH 007/431] Pydantic models (#28697)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
.../console/app/workflow_trigger.py | 42 +-
api/controllers/console/workspace/account.py | 443 ++++++++++------
api/controllers/console/workspace/members.py | 123 +++--
.../console/workspace/model_providers.py | 215 +++++---
api/controllers/console/workspace/models.py | 458 ++++++++--------
api/controllers/console/workspace/plugin.py | 490 ++++++++++--------
.../console/workspace/workspace.py | 94 ++--
api/services/account_service.py | 2 +-
8 files changed, 1065 insertions(+), 802 deletions(-)
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index 597ff1f6c5..b3e5c9619f 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,6 +1,8 @@
import logging
-from flask_restx import Resource, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -18,16 +20,30 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
+class Parser(BaseModel):
+ node_id: str
+
+
+class ParserEnable(BaseModel):
+ trigger_id: str
+ enable_trigger: bool
+
+
+console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+console_ns.schema_model(
+ ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
@console_ns.route("/apps//workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
- @console_ns.expect(parser)
+ @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -35,9 +51,9 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
- args = parser.parse_args()
+ args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
- node_id = str(args["node_id"])
+ node_id = args.node_id
with Session(db.engine) as session:
# Get webhook trigger for this app and node
@@ -96,16 +112,9 @@ class AppTriggersApi(Resource):
return {"data": triggers}
-parser_enable = (
- reqparse.RequestParser()
- .add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
- .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/apps//trigger-enable")
class AppTriggerEnableApi(Resource):
- @console_ns.expect(parser_enable)
+ @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -114,12 +123,11 @@ class AppTriggerEnableApi(Resource):
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
- args = parser_enable.parse_args()
+ args = ParserEnable.model_validate(console_ns.payload)
assert current_user.current_tenant_id is not None
- trigger_id = args["trigger_id"]
-
+ trigger_id = args.trigger_id
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
@@ -134,7 +142,7 @@ class AppTriggerEnableApi(Resource):
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
- trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
+ trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 838cd3ee95..b4d1b42657 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,8 +1,10 @@
from datetime import datetime
+from typing import Literal
import pytz
from flask import request
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -42,20 +44,198 @@ from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-def _init_parser():
- parser = reqparse.RequestParser()
- if dify_config.EDITION == "CLOUD":
- parser.add_argument("invitation_code", type=str, location="json")
- parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
- "timezone", type=timezone, required=True, location="json"
- )
- return parser
+
+class AccountInitPayload(BaseModel):
+ interface_language: str
+ timezone: str
+ invitation_code: str | None = None
+
+ @field_validator("interface_language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+ @field_validator("timezone")
+ @classmethod
+ def validate_timezone(cls, value: str) -> str:
+ return timezone(value)
+
+
+class AccountNamePayload(BaseModel):
+ name: str = Field(min_length=3, max_length=30)
+
+
+class AccountAvatarPayload(BaseModel):
+ avatar: str
+
+
+class AccountInterfaceLanguagePayload(BaseModel):
+ interface_language: str
+
+ @field_validator("interface_language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+
+class AccountInterfaceThemePayload(BaseModel):
+ interface_theme: Literal["light", "dark"]
+
+
+class AccountTimezonePayload(BaseModel):
+ timezone: str
+
+ @field_validator("timezone")
+ @classmethod
+ def validate_timezone(cls, value: str) -> str:
+ return timezone(value)
+
+
+class AccountPasswordPayload(BaseModel):
+ password: str | None = None
+ new_password: str
+ repeat_new_password: str
+
+ @model_validator(mode="after")
+ def check_passwords_match(self) -> "AccountPasswordPayload":
+ if self.new_password != self.repeat_new_password:
+ raise RepeatPasswordNotMatchError()
+ return self
+
+
+class AccountDeletePayload(BaseModel):
+ token: str
+ code: str
+
+
+class AccountDeletionFeedbackPayload(BaseModel):
+ email: str
+ feedback: str
+
+ @field_validator("email")
+ @classmethod
+ def validate_email(cls, value: str) -> str:
+ return email(value)
+
+
+class EducationActivatePayload(BaseModel):
+ token: str
+ institution: str
+ role: str
+
+
+class EducationAutocompleteQuery(BaseModel):
+ keywords: str
+ page: int = 0
+ limit: int = 20
+
+
+class ChangeEmailSendPayload(BaseModel):
+ email: str
+ language: str | None = None
+ phase: str | None = None
+ token: str | None = None
+
+ @field_validator("email")
+ @classmethod
+ def validate_email(cls, value: str) -> str:
+ return email(value)
+
+
+class ChangeEmailValidityPayload(BaseModel):
+ email: str
+ code: str
+ token: str
+
+ @field_validator("email")
+ @classmethod
+ def validate_email(cls, value: str) -> str:
+ return email(value)
+
+
+class ChangeEmailResetPayload(BaseModel):
+ new_email: str
+ token: str
+
+ @field_validator("new_email")
+ @classmethod
+ def validate_email(cls, value: str) -> str:
+ return email(value)
+
+
+class CheckEmailUniquePayload(BaseModel):
+ email: str
+
+ @field_validator("email")
+ @classmethod
+ def validate_email(cls, value: str) -> str:
+ return email(value)
+
+
+console_ns.schema_model(
+ AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+ AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+ AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+ AccountInterfaceLanguagePayload.__name__,
+ AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ AccountInterfaceThemePayload.__name__,
+ AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ AccountTimezonePayload.__name__,
+ AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ AccountPasswordPayload.__name__,
+ AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ AccountDeletePayload.__name__,
+ AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ AccountDeletionFeedbackPayload.__name__,
+ AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ EducationActivatePayload.__name__,
+ EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ EducationAutocompleteQuery.__name__,
+ EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ ChangeEmailSendPayload.__name__,
+ ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ ChangeEmailValidityPayload.__name__,
+ ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ ChangeEmailResetPayload.__name__,
+ ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ CheckEmailUniquePayload.__name__,
+ CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
@console_ns.route("/account/init")
class AccountInitApi(Resource):
- @console_ns.expect(_init_parser())
+ @console_ns.expect(console_ns.models[AccountInitPayload.__name__])
@setup_required
@login_required
def post(self):
@@ -64,17 +244,18 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
- args = _init_parser().parse_args()
+ payload = console_ns.payload or {}
+ args = AccountInitPayload.model_validate(payload)
if dify_config.EDITION == "CLOUD":
- if not args["invitation_code"]:
+ if not args.invitation_code:
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.where(
- InvitationCode.code == args["invitation_code"],
+ InvitationCode.code == args.invitation_code,
InvitationCode.status == "unused",
)
.first()
@@ -88,8 +269,8 @@ class AccountInitApi(Resource):
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
- account.interface_language = args["interface_language"]
- account.timezone = args["timezone"]
+ account.interface_language = args.interface_language
+ account.timezone = args.timezone
account.interface_theme = "light"
account.status = "active"
account.initialized_at = naive_utc_now()
@@ -110,137 +291,104 @@ class AccountProfileApi(Resource):
return current_user
-parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
-
-
@console_ns.route("/account/name")
class AccountNameApi(Resource):
- @console_ns.expect(parser_name)
+ @console_ns.expect(console_ns.models[AccountNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_name.parse_args()
-
- # Validate account name length
- if len(args["name"]) < 3 or len(args["name"]) > 30:
- raise ValueError("Account name must be between 3 and 30 characters.")
-
- updated_account = AccountService.update_account(current_user, name=args["name"])
+ payload = console_ns.payload or {}
+ args = AccountNamePayload.model_validate(payload)
+ updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
-parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
-
-
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
- @console_ns.expect(parser_avatar)
+ @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_avatar.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountAvatarPayload.model_validate(payload)
- updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
+ updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
-parser_interface = reqparse.RequestParser().add_argument(
- "interface_language", type=supported_language, required=True, location="json"
-)
-
-
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
- @console_ns.expect(parser_interface)
+ @console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_interface.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountInterfaceLanguagePayload.model_validate(payload)
- updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
+ updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
-parser_theme = reqparse.RequestParser().add_argument(
- "interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
-)
-
-
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
- @console_ns.expect(parser_theme)
+ @console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_theme.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountInterfaceThemePayload.model_validate(payload)
- updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
+ updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
-parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
-
-
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
- @console_ns.expect(parser_timezone)
+ @console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_timezone.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountTimezonePayload.model_validate(payload)
- # Validate timezone string, e.g. America/New_York, Asia/Shanghai
- if args["timezone"] not in pytz.all_timezones:
- raise ValueError("Invalid timezone string.")
-
- updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
+ updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
-parser_pw = (
- reqparse.RequestParser()
- .add_argument("password", type=str, required=False, location="json")
- .add_argument("new_password", type=str, required=True, location="json")
- .add_argument("repeat_new_password", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
- @console_ns.expect(parser_pw)
+ @console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_pw.parse_args()
-
- if args["new_password"] != args["repeat_new_password"]:
- raise RepeatPasswordNotMatchError()
+ payload = console_ns.payload or {}
+ args = AccountPasswordPayload.model_validate(payload)
try:
- AccountService.update_account_password(current_user, args["password"], args["new_password"])
+ AccountService.update_account_password(current_user, args.password, args.new_password)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@@ -316,25 +464,19 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
-parser_delete = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
- @console_ns.expect(parser_delete)
+ @console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
- args = parser_delete.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountDeletePayload.model_validate(payload)
- if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
+ if not AccountService.verify_account_deletion_code(args.token, args.code):
raise InvalidAccountDeletionCodeError()
AccountService.delete_account(account)
@@ -342,21 +484,15 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
-parser_feedback = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("feedback", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
- @console_ns.expect(parser_feedback)
+ @console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
@setup_required
def post(self):
- args = parser_feedback.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountDeletionFeedbackPayload.model_validate(payload)
- BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
+ BillingService.update_account_deletion_feedback(args.email, args.feedback)
return {"result": "success"}
@@ -379,14 +515,6 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
-parser_edu = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, location="json")
- .add_argument("institution", type=str, required=True, location="json")
- .add_argument("role", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@@ -396,7 +524,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
- @console_ns.expect(parser_edu)
+ @console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -405,9 +533,10 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
- args = parser_edu.parse_args()
+ payload = console_ns.payload or {}
+ args = EducationActivatePayload.model_validate(payload)
- return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
+ return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)
@setup_required
@login_required
@@ -425,14 +554,6 @@ class EducationApi(Resource):
return res
-parser_autocomplete = (
- reqparse.RequestParser()
- .add_argument("keywords", type=str, required=True, location="args")
- .add_argument("page", type=int, required=False, location="args", default=0)
- .add_argument("limit", type=int, required=False, location="args", default=20)
-)
-
-
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@@ -441,7 +562,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
- @console_ns.expect(parser_autocomplete)
+ @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -449,46 +570,39 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
- args = parser_autocomplete.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = EducationAutocompleteQuery.model_validate(payload)
- return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
-
-
-parser_change_email = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- .add_argument("phase", type=str, required=False, location="json")
- .add_argument("token", type=str, required=False, location="json")
-)
+ return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
- @console_ns.expect(parser_change_email)
+ @console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_change_email.parse_args()
+ payload = console_ns.payload or {}
+ args = ChangeEmailSendPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = None
- user_email = args["email"]
- if args["phase"] is not None and args["phase"] == "new_email":
- if args["token"] is None:
+ user_email = args.email
+ if args.phase is not None and args.phase == "new_email":
+ if args.token is None:
raise InvalidTokenError()
- reset_data = AccountService.get_change_email_data(args["token"])
+ reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
@@ -497,118 +611,103 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidEmailError()
else:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
token = AccountService.send_change_email_email(
- account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
+ account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
-parser_validity = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
- @console_ns.expect(parser_validity)
+ @console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
- args = parser_validity.parse_args()
+ payload = console_ns.payload or {}
+ args = ChangeEmailValidityPayload.model_validate(payload)
- user_email = args["email"]
+ user_email = args.email
- is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
- token_data = AccountService.get_change_email_data(args["token"])
+ token_data = AccountService.get_change_email_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
- AccountService.add_change_email_error_rate_limit(args["email"])
+ if args.code != token_data.get("code"):
+ AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_change_email_token(args["token"])
+ AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
- user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
+ user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
- AccountService.reset_change_email_error_rate_limit(args["email"])
+ AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
-parser_reset = (
- reqparse.RequestParser()
- .add_argument("new_email", type=email, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
- @console_ns.expect(parser_reset)
+ @console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
- args = parser_reset.parse_args()
+ payload = console_ns.payload or {}
+ args = ChangeEmailResetPayload.model_validate(payload)
- if AccountService.is_account_in_freeze(args["new_email"]):
+ if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args["new_email"]):
+ if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
- reset_data = AccountService.get_change_email_data(args["token"])
+ reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()
- AccountService.revoke_change_email_token(args["token"])
+ AccountService.revoke_change_email_token(args.token)
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()
- updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
+ updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
- email=args["new_email"],
+ email=args.new_email,
)
return updated_account
-parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
-
-
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
- @console_ns.expect(parser_check)
+ @console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
@setup_required
def post(self):
- args = parser_check.parse_args()
- if AccountService.is_account_in_freeze(args["email"]):
+ payload = console_ns.payload or {}
+ args = CheckEmailUniquePayload.model_validate(payload)
+ if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args["email"]):
+ if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index f17f8e4bcf..f72d247398 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,7 +1,8 @@
from urllib import parse
from flask import abort, request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
import services
from configs import dify_config
@@ -31,6 +32,53 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class MemberInvitePayload(BaseModel):
+ emails: list[str] = Field(default_factory=list)
+ role: TenantAccountRole
+ language: str | None = None
+
+
+class MemberRoleUpdatePayload(BaseModel):
+ role: str
+
+
+class OwnerTransferEmailPayload(BaseModel):
+ language: str | None = None
+
+
+class OwnerTransferCheckPayload(BaseModel):
+ code: str
+ token: str
+
+
+class OwnerTransferPayload(BaseModel):
+ token: str
+
+
+console_ns.schema_model(
+ MemberInvitePayload.__name__,
+ MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ MemberRoleUpdatePayload.__name__,
+ MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ OwnerTransferEmailPayload.__name__,
+ OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ OwnerTransferCheckPayload.__name__,
+ OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ OwnerTransferPayload.__name__,
+ OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
@@ -48,29 +96,22 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
-parser_invite = (
- reqparse.RequestParser()
- .add_argument("emails", type=list, required=True, location="json")
- .add_argument("role", type=str, required=True, default="admin", location="json")
- .add_argument("language", type=str, required=False, location="json")
-)
-
-
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
- @console_ns.expect(parser_invite)
+ @console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
- args = parser_invite.parse_args()
+ payload = console_ns.payload or {}
+ args = MemberInvitePayload.model_validate(payload)
- invitee_emails = args["emails"]
- invitee_role = args["role"]
- interface_language = args["language"]
+ invitee_emails = args.emails
+ invitee_role = args.role
+ interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
@@ -146,20 +187,18 @@ class MemberCancelInviteApi(Resource):
}, 200
-parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
-
-
@console_ns.route("/workspaces/current/members//update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
- @console_ns.expect(parser_update)
+ @console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
- args = parser_update.parse_args()
- new_role = args["role"]
+ payload = console_ns.payload or {}
+ args = MemberRoleUpdatePayload.model_validate(payload)
+ new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
@@ -197,20 +236,18 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
-parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
-
-
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
- @console_ns.expect(parser_send)
+ @console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
- args = parser_send.parse_args()
+ payload = console_ns.payload or {}
+ args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -221,7 +258,7 @@ class SendOwnerTransferEmailApi(Resource):
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
@@ -238,22 +275,16 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
-parser_owner = (
- reqparse.RequestParser()
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
- @console_ns.expect(parser_owner)
+ @console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
- args = parser_owner.parse_args()
+ payload = console_ns.payload or {}
+ args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@@ -267,41 +298,37 @@ class OwnerTransferCheckApi(Resource):
if is_owner_transfer_error_rate_limit:
raise OwnerTransferLimitError()
- token_data = AccountService.get_owner_transfer_data(args["token"])
+ token_data = AccountService.get_owner_transfer_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
+ if args.code != token_data.get("code"):
AccountService.add_owner_transfer_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_owner_transfer_token(args["token"])
+ AccountService.revoke_owner_transfer_token(args.token)
# Refresh token data by generating a new token
- _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
+ _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={})
AccountService.reset_owner_transfer_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
-parser_owner_transfer = reqparse.RequestParser().add_argument(
- "token", type=str, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/members//owner-transfer")
class OwnerTransfer(Resource):
- @console_ns.expect(parser_owner_transfer)
+ @console_ns.expect(console_ns.models[OwnerTransferPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
- args = parser_owner_transfer.parse_args()
+ payload = console_ns.payload or {}
+ args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
@@ -313,14 +340,14 @@ class OwnerTransfer(Resource):
if current_user.id == str(member_id):
raise CannotTransferOwnerToSelfError()
- transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
+ transfer_token_data = AccountService.get_owner_transfer_data(args.token)
if not transfer_token_data:
raise InvalidTokenError()
if transfer_token_data.get("email") != current_user.email:
raise InvalidEmailError()
- AccountService.revoke_owner_transfer_token(args["token"])
+ AccountService.revoke_owner_transfer_token(args.token)
member = db.session.get(Account, str(member_id))
if not member:
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index 8ca69121bf..d40748d5e3 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -1,31 +1,123 @@
import io
+from typing import Any, Literal
-from flask import send_file
-from flask_restx import Resource, reqparse
+from flask import request, send_file
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
-from libs.helper import StrLen, uuid_value
+from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
-parser_model = reqparse.RequestParser().add_argument(
- "model_type",
- type=str,
- required=False,
- nullable=True,
- choices=[mt.value for mt in ModelType],
- location="args",
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ParserModelList(BaseModel):
+ model_type: ModelType | None = None
+
+
+class ParserCredentialId(BaseModel):
+ credential_id: str | None = None
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_optional_credential_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ParserCredentialCreate(BaseModel):
+ credentials: dict[str, Any]
+ name: str | None = Field(default=None, max_length=30)
+
+
+class ParserCredentialUpdate(BaseModel):
+ credential_id: str
+ credentials: dict[str, Any]
+ name: str | None = Field(default=None, max_length=30)
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_update_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserCredentialDelete(BaseModel):
+ credential_id: str
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_delete_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserCredentialSwitch(BaseModel):
+ credential_id: str
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_switch_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserCredentialValidate(BaseModel):
+ credentials: dict[str, Any]
+
+
+class ParserPreferredProviderType(BaseModel):
+ preferred_provider_type: Literal["system", "custom"]
+
+
+console_ns.schema_model(
+ ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserCredentialId.__name__,
+ ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserCredentialCreate.__name__,
+ ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserCredentialUpdate.__name__,
+ ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserCredentialDelete.__name__,
+ ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserCredentialSwitch.__name__,
+ ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserCredentialValidate.__name__,
+ ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserPreferredProviderType.__name__,
+ ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
- @console_ns.expect(parser_model)
+ @console_ns.expect(console_ns.models[ParserModelList.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -33,38 +125,18 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
- args = parser_model.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService()
- provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
+ provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type)
return jsonable_encoder({"data": provider_list})
-parser_cred = reqparse.RequestParser().add_argument(
- "credential_id", type=uuid_value, required=False, nullable=True, location="args"
-)
-parser_post_cred = (
- reqparse.RequestParser()
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
-)
-
-parser_put_cred = (
- reqparse.RequestParser()
- .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
-)
-
-parser_delete_cred = reqparse.RequestParser().add_argument(
- "credential_id", type=uuid_value, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/model-providers//credentials")
class ModelProviderCredentialApi(Resource):
- @console_ns.expect(parser_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialId.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -72,23 +144,25 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
- args = parser_cred.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
- tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
+ tenant_id=tenant_id, provider=provider, credential_id=args.credential_id
)
return {"credentials": credentials}
- @console_ns.expect(parser_post_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialCreate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_post_cred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialCreate.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -96,15 +170,15 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.create_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
- credentials=args["credentials"],
- credential_name=args["name"],
+ credentials=args.credentials,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
- @console_ns.expect(parser_put_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -112,7 +186,8 @@ class ModelProviderCredentialApi(Resource):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_put_cred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialUpdate.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -120,71 +195,64 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.update_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
- credentials=args["credentials"],
- credential_id=args["credential_id"],
- credential_name=args["name"],
+ credentials=args.credentials,
+ credential_id=args.credential_id,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
- @console_ns.expect(parser_delete_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_delete_cred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialDelete.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
- tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
+ tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return {"result": "success"}, 204
-parser_switch = reqparse.RequestParser().add_argument(
- "credential_id", type=str, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/model-providers//credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
- @console_ns.expect(parser_switch)
+ @console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_switch.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialSwitch.model_validate(payload)
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
- credential_id=args["credential_id"],
+ credential_id=args.credential_id,
)
return {"result": "success"}
-parser_validate = reqparse.RequestParser().add_argument(
- "credentials", type=dict, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/model-providers//credentials/validate")
class ModelProviderValidateApi(Resource):
- @console_ns.expect(parser_validate)
+ @console_ns.expect(console_ns.models[ParserCredentialValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_validate.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialValidate.model_validate(payload)
tenant_id = current_tenant_id
@@ -195,7 +263,7 @@ class ModelProviderValidateApi(Resource):
try:
model_provider_service.validate_provider_credentials(
- tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
+ tenant_id=tenant_id, provider=provider, credentials=args.credentials
)
except CredentialsValidateFailedError as ex:
result = False
@@ -228,19 +296,9 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
-parser_preferred = reqparse.RequestParser().add_argument(
- "preferred_provider_type",
- type=str,
- required=True,
- nullable=False,
- choices=["system", "custom"],
- location="json",
-)
-
-
@console_ns.route("/workspaces/current/model-providers//preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
- @console_ns.expect(parser_preferred)
+ @console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -250,11 +308,12 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id
- args = parser_preferred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserPreferredProviderType.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
- tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
+ tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type
)
return {"result": "success"}
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 2aca73806a..8e402b4bae 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,52 +1,172 @@
import logging
+from typing import Any
-from flask_restx import Resource, reqparse
+from flask import request
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
-from libs.helper import StrLen, uuid_value
+from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-parser_get_default = reqparse.RequestParser().add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="args",
+class ParserGetDefault(BaseModel):
+ model_type: ModelType
+
+
+class ParserPostDefault(BaseModel):
+ class Inner(BaseModel):
+ model_type: ModelType
+ model: str
+ provider: str | None = None
+
+ model_settings: list[Inner]
+
+
+console_ns.schema_model(
+ ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
-parser_post_default = reqparse.RequestParser().add_argument(
- "model_settings", type=list, required=True, nullable=False, location="json"
+
+console_ns.schema_model(
+ ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+
+class ParserDeleteModels(BaseModel):
+ model: str
+ model_type: ModelType
+
+
+console_ns.schema_model(
+ ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+
+class LoadBalancingPayload(BaseModel):
+ configs: list[dict[str, Any]] | None = None
+ enabled: bool | None = None
+
+
+class ParserPostModels(BaseModel):
+ model: str
+ model_type: ModelType
+ load_balancing: LoadBalancingPayload | None = None
+ config_from: str | None = None
+ credential_id: str | None = None
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_credential_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ParserGetCredentials(BaseModel):
+ model: str
+ model_type: ModelType
+ config_from: str | None = None
+ credential_id: str | None = None
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_get_credential_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ParserCredentialBase(BaseModel):
+ model: str
+ model_type: ModelType
+
+
+class ParserCreateCredential(ParserCredentialBase):
+ name: str | None = Field(default=None, max_length=30)
+ credentials: dict[str, Any]
+
+
+class ParserUpdateCredential(ParserCredentialBase):
+ credential_id: str
+ credentials: dict[str, Any]
+ name: str | None = Field(default=None, max_length=30)
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_update_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserDeleteCredential(ParserCredentialBase):
+ credential_id: str
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_delete_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserParameter(BaseModel):
+ model: str
+
+
+console_ns.schema_model(
+ ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserGetCredentials.__name__,
+ ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserCreateCredential.__name__,
+ ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserUpdateCredential.__name__,
+ ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserDeleteCredential.__name__,
+ ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
- @console_ns.expect(parser_get_default)
+ @console_ns.expect(console_ns.models[ParserGetDefault.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_get_default.parse_args()
+ args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
- tenant_id=tenant_id, model_type=args["model_type"]
+ tenant_id=tenant_id, model_type=args.model_type
)
return jsonable_encoder({"data": default_model_entity})
- @console_ns.expect(parser_post_default)
+ @console_ns.expect(console_ns.models[ParserPostDefault.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -54,66 +174,31 @@ class DefaultModelApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_post_default.parse_args()
+ args = ParserPostDefault.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
- model_settings = args["model_settings"]
+ model_settings = args.model_settings
for model_setting in model_settings:
- if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
- raise ValueError("invalid model type")
-
- if "provider" not in model_setting:
+ if model_setting.provider is None:
continue
- if "model" not in model_setting:
- raise ValueError("invalid model")
-
try:
model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id,
- model_type=model_setting["model_type"],
- provider=model_setting["provider"],
- model=model_setting["model"],
+ model_type=model_setting.model_type,
+ provider=model_setting.provider,
+ model=model_setting.model,
)
except Exception as ex:
logger.exception(
"Failed to update default model, model type: %s, model: %s",
- model_setting["model_type"],
- model_setting.get("model"),
+ model_setting.model_type,
+ model_setting.model,
)
raise ex
return {"result": "success"}
-parser_post_models = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
- .add_argument("config_from", type=str, required=False, nullable=True, location="json")
- .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
-)
-parser_delete_models = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
-)
-
-
@console_ns.route("/workspaces/current/model-providers//models")
class ModelProviderModelApi(Resource):
@setup_required
@@ -127,7 +212,7 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
- @console_ns.expect(parser_post_models)
+ @console_ns.expect(console_ns.models[ParserPostModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -135,45 +220,45 @@ class ModelProviderModelApi(Resource):
def post(self, provider: str):
# To save the model's load balance configs
_, tenant_id = current_account_with_tenant()
- args = parser_post_models.parse_args()
+ args = ParserPostModels.model_validate(console_ns.payload)
- if args.get("config_from", "") == "custom-model":
- if not args.get("credential_id"):
+ if args.config_from == "custom-model":
+ if not args.credential_id:
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args["credential_id"],
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
- if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
+ if args.load_balancing and args.load_balancing.configs:
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- configs=args["load_balancing"]["configs"],
- config_from=args.get("config_from", ""),
+ model=args.model,
+ model_type=args.model_type,
+ configs=args.load_balancing.configs,
+ config_from=args.config_from or "",
)
- if args.get("load_balancing", {}).get("enabled"):
+ if args.load_balancing.enabled:
model_load_balancing_service.enable_model_load_balancing(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
else:
model_load_balancing_service.disable_model_load_balancing(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 200
- @console_ns.expect(parser_delete_models)
+ @console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -181,113 +266,53 @@ class ModelProviderModelApi(Resource):
def delete(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_delete_models.parse_args()
+ args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 204
-parser_get_credentials = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="args")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="args",
- )
- .add_argument("config_from", type=str, required=False, nullable=True, location="args")
- .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
-)
-
-
-parser_post_cred = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-)
-parser_put_cred = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
-)
-parser_delete_cred = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/workspaces/current/model-providers//models/credentials")
class ModelProviderModelCredentialApi(Resource):
- @console_ns.expect(parser_get_credentials)
+ @console_ns.expect(console_ns.models[ParserGetCredentials.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_get_credentials.parse_args()
+ args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
tenant_id=tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args.get("credential_id"),
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- config_from=args.get("config_from", ""),
+ model=args.model,
+ model_type=args.model_type,
+ config_from=args.config_from or "",
)
- if args.get("config_from", "") == "predefined-model":
+ if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
)
else:
- model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
+ model_type = args.model_type
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
- tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
+ tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
)
return jsonable_encoder(
@@ -304,7 +329,7 @@ class ModelProviderModelCredentialApi(Resource):
}
)
- @console_ns.expect(parser_post_cred)
+ @console_ns.expect(console_ns.models[ParserCreateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -312,7 +337,7 @@ class ModelProviderModelCredentialApi(Resource):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_post_cred.parse_args()
+ args = ParserCreateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -320,30 +345,30 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.create_model_credential(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
- credential_name=args["name"],
+ model=args.model,
+ model_type=args.model_type,
+ credentials=args.credentials,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
logger.exception(
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
tenant_id,
- args.get("model"),
- args.get("model_type"),
+ args.model,
+ args.model_type,
)
raise ValueError(str(ex))
return {"result": "success"}, 201
- @console_ns.expect(parser_put_cred)
+ @console_ns.expect(console_ns.models[ParserUpdateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_put_cred.parse_args()
+ args = ParserUpdateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -351,106 +376,87 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.update_model_credential(
tenant_id=current_tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credentials=args["credentials"],
- credential_id=args["credential_id"],
- credential_name=args["name"],
+ model_type=args.model_type,
+ model=args.model,
+ credentials=args.credentials,
+ credential_id=args.credential_id,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
- @console_ns.expect(parser_delete_cred)
+ @console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_delete_cred.parse_args()
+ args = ParserDeleteCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args["credential_id"],
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
return {"result": "success"}, 204
-parser_switch = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+class ParserSwitch(BaseModel):
+ model: str
+ model_type: ModelType
+ credential_id: str
+
+
+console_ns.schema_model(
+ ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers//models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
- @console_ns.expect(parser_switch)
+ @console_ns.expect(console_ns.models[ParserSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
-
- args = parser_switch.parse_args()
+ args = ParserSwitch.model_validate(console_ns.payload)
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args["credential_id"],
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
return {"result": "success"}
-parser_model_enable_disable = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
-)
-
-
@console_ns.route(
"/workspaces/current/model-providers//models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
- @console_ns.expect(parser_model_enable_disable)
+ @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_model_enable_disable.parse_args()
+ args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
@@ -460,48 +466,43 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers//models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
- @console_ns.expect(parser_model_enable_disable)
+ @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_model_enable_disable.parse_args()
+ args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
-parser_validate = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+class ParserValidate(BaseModel):
+ model: str
+ model_type: ModelType
+ credentials: dict
+
+
+console_ns.schema_model(
+ ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers//models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
- @console_ns.expect(parser_validate)
+ @console_ns.expect(console_ns.models[ParserValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
-
- args = parser_validate.parse_args()
+ args = ParserValidate.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -512,9 +513,9 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service.validate_model_credentials(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
+ model=args.model,
+ model_type=args.model_type,
+ credentials=args.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@@ -528,24 +529,19 @@ class ModelProviderModelValidateApi(Resource):
return response
-parser_parameter = reqparse.RequestParser().add_argument(
- "model", type=str, required=True, nullable=False, location="args"
-)
-
-
@console_ns.route("/workspaces/current/model-providers//models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
- @console_ns.expect(parser_parameter)
+ @console_ns.expect(console_ns.models[ParserParameter.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
- args = parser_parameter.parse_args()
+ args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
- tenant_id=tenant_id, provider=provider, model=args["model"]
+ tenant_id=tenant_id, provider=provider, model=args.model
)
return jsonable_encoder({"data": parameter_rules})
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index e3345033f8..7e08ea55f9 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -1,7 +1,9 @@
import io
+from typing import Literal
from flask import request, send_file
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -17,6 +19,8 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@@ -37,88 +41,251 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
-parser_list = (
- reqparse.RequestParser()
- .add_argument("page", type=int, required=False, location="args", default=1)
- .add_argument("page_size", type=int, required=False, location="args", default=256)
+class ParserList(BaseModel):
+ page: int = Field(default=1)
+ page_size: int = Field(default=256)
+
+
+console_ns.schema_model(
+ ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
- @console_ns.expect(parser_list)
+ @console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_list.parse_args()
+ args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
+ plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
-parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
+class ParserLatest(BaseModel):
+ plugin_ids: list[str]
+
+
+console_ns.schema_model(
+ ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+
+class ParserIcon(BaseModel):
+ tenant_id: str
+ filename: str
+
+
+class ParserAsset(BaseModel):
+ plugin_unique_identifier: str
+ file_name: str
+
+
+class ParserGithubUpload(BaseModel):
+ repo: str
+ version: str
+ package: str
+
+
+class ParserPluginIdentifiers(BaseModel):
+ plugin_unique_identifiers: list[str]
+
+
+class ParserGithubInstall(BaseModel):
+ plugin_unique_identifier: str
+ repo: str
+ version: str
+ package: str
+
+
+class ParserPluginIdentifierQuery(BaseModel):
+ plugin_unique_identifier: str
+
+
+class ParserTasks(BaseModel):
+ page: int
+ page_size: int
+
+
+class ParserMarketplaceUpgrade(BaseModel):
+ original_plugin_unique_identifier: str
+ new_plugin_unique_identifier: str
+
+
+class ParserGithubUpgrade(BaseModel):
+ original_plugin_unique_identifier: str
+ new_plugin_unique_identifier: str
+ repo: str
+ version: str
+ package: str
+
+
+class ParserUninstall(BaseModel):
+ plugin_installation_id: str
+
+
+class ParserPermissionChange(BaseModel):
+ install_permission: TenantPluginPermission.InstallPermission
+ debug_permission: TenantPluginPermission.DebugPermission
+
+
+class ParserDynamicOptions(BaseModel):
+ plugin_id: str
+ provider: str
+ action: str
+ parameter: str
+ credential_id: str | None = None
+ provider_type: Literal["tool", "trigger"]
+
+
+class PluginPermissionSettingsPayload(BaseModel):
+ install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
+ debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
+
+
+class PluginAutoUpgradeSettingsPayload(BaseModel):
+ strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
+ TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
+ )
+ upgrade_time_of_day: int = 0
+ upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
+ exclude_plugins: list[str] = Field(default_factory=list)
+ include_plugins: list[str] = Field(default_factory=list)
+
+
+class ParserPreferencesChange(BaseModel):
+ permission: PluginPermissionSettingsPayload
+ auto_upgrade: PluginAutoUpgradeSettingsPayload
+
+
+class ParserExcludePlugin(BaseModel):
+ plugin_id: str
+
+
+class ParserReadme(BaseModel):
+ plugin_unique_identifier: str
+ language: str = Field(default="en-US")
+
+
+console_ns.schema_model(
+ ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserPluginIdentifiers.__name__,
+ ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserPluginIdentifierQuery.__name__,
+ ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserMarketplaceUpgrade.__name__,
+ ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ ParserPermissionChange.__name__,
+ ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserDynamicOptions.__name__,
+ ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserPreferencesChange.__name__,
+ ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserExcludePlugin.__name__,
+ ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
- @console_ns.expect(parser_latest)
+ @console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
- args = parser_latest.parse_args()
+ args = ParserLatest.model_validate(console_ns.payload)
try:
- versions = PluginService.list_latest_versions(args["plugin_ids"])
+ versions = PluginService.list_latest_versions(args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"versions": versions})
-parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
-
-
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
- @console_ns.expect(parser_ids)
+ @console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_ids.parse_args()
+ args = ParserLatest.model_validate(console_ns.payload)
try:
- plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
+ plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
-parser_icon = (
- reqparse.RequestParser()
- .add_argument("tenant_id", type=str, required=True, location="args")
- .add_argument("filename", type=str, required=True, location="args")
-)
-
-
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
- @console_ns.expect(parser_icon)
+ @console_ns.expect(console_ns.models[ParserIcon.__name__])
@setup_required
def get(self):
- args = parser_icon.parse_args()
+ args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
+ icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -128,20 +295,16 @@ class PluginIconApi(Resource):
@console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource):
+ @console_ns.expect(console_ns.models[ParserAsset.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
- req = (
- reqparse.RequestParser()
- .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
- .add_argument("file_name", type=str, required=True, location="args")
- )
- args = req.parse_args()
+ args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
try:
- binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
+ binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -171,17 +334,9 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
-parser_github = (
- reqparse.RequestParser()
- .add_argument("repo", type=str, required=True, location="json")
- .add_argument("version", type=str, required=True, location="json")
- .add_argument("package", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
- @console_ns.expect(parser_github)
+ @console_ns.expect(console_ns.models[ParserGithubUpload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -189,10 +344,10 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_github.parse_args()
+ args = ParserGithubUpload.model_validate(console_ns.payload)
try:
- response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
+ response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -223,47 +378,28 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
-parser_pkg = reqparse.RequestParser().add_argument(
- "plugin_unique_identifiers", type=list, required=True, location="json"
-)
-
-
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
- @console_ns.expect(parser_pkg)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_pkg.parse_args()
-
- # check if all plugin_unique_identifiers are valid string
- for plugin_unique_identifier in args["plugin_unique_identifiers"]:
- if not isinstance(plugin_unique_identifier, str):
- raise ValueError("Invalid plugin unique identifier")
+ args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
- response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
+ response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
-parser_githubapi = (
- reqparse.RequestParser()
- .add_argument("repo", type=str, required=True, location="json")
- .add_argument("version", type=str, required=True, location="json")
- .add_argument("package", type=str, required=True, location="json")
- .add_argument("plugin_unique_identifier", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
- @console_ns.expect(parser_githubapi)
+ @console_ns.expect(console_ns.models[ParserGithubInstall.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -271,15 +407,15 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_githubapi.parse_args()
+ args = ParserGithubInstall.model_validate(console_ns.payload)
try:
response = PluginService.install_from_github(
tenant_id,
- args["plugin_unique_identifier"],
- args["repo"],
- args["version"],
- args["package"],
+ args.plugin_unique_identifier,
+ args.repo,
+ args.version,
+ args.package,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -287,14 +423,9 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
-parser_marketplace = reqparse.RequestParser().add_argument(
- "plugin_unique_identifiers", type=list, required=True, location="json"
-)
-
-
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
- @console_ns.expect(parser_marketplace)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -302,43 +433,33 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_marketplace.parse_args()
-
- # check if all plugin_unique_identifiers are valid string
- for plugin_unique_identifier in args["plugin_unique_identifiers"]:
- if not isinstance(plugin_unique_identifier, str):
- raise ValueError("Invalid plugin unique identifier")
+ args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
- response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
+ response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
-parser_pkgapi = reqparse.RequestParser().add_argument(
- "plugin_unique_identifier", type=str, required=True, location="args"
-)
-
-
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
- @console_ns.expect(parser_pkgapi)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_pkgapi.parse_args()
+ args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
- args["plugin_unique_identifier"],
+ args.plugin_unique_identifier,
)
}
)
@@ -346,14 +467,9 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
-parser_fetch = reqparse.RequestParser().add_argument(
- "plugin_unique_identifier", type=str, required=True, location="args"
-)
-
-
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
- @console_ns.expect(parser_fetch)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -361,30 +477,19 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_fetch.parse_args()
+ args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
- {
- "manifest": PluginService.fetch_plugin_manifest(
- tenant_id, args["plugin_unique_identifier"]
- ).model_dump()
- }
+ {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_tasks = (
- reqparse.RequestParser()
- .add_argument("page", type=int, required=True, location="args")
- .add_argument("page_size", type=int, required=True, location="args")
-)
-
-
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
- @console_ns.expect(parser_tasks)
+ @console_ns.expect(console_ns.models[ParserTasks.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -392,12 +497,10 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_tasks.parse_args()
+ args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- return jsonable_encoder(
- {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
- )
+ return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -462,16 +565,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
-parser_marketplace_api = (
- reqparse.RequestParser()
- .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
- .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
- @console_ns.expect(parser_marketplace_api)
+ @console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -479,31 +575,21 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_marketplace_api.parse_args()
+ args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
- tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
+ tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_github_post = (
- reqparse.RequestParser()
- .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
- .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
- .add_argument("repo", type=str, required=True, location="json")
- .add_argument("version", type=str, required=True, location="json")
- .add_argument("package", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
- @console_ns.expect(parser_github_post)
+ @console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -511,56 +597,44 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_github_post.parse_args()
+ args = ParserGithubUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
- args["original_plugin_unique_identifier"],
- args["new_plugin_unique_identifier"],
- args["repo"],
- args["version"],
- args["package"],
+ args.original_plugin_unique_identifier,
+ args.new_plugin_unique_identifier,
+ args.repo,
+ args.version,
+ args.package,
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_uninstall = reqparse.RequestParser().add_argument(
- "plugin_installation_id", type=str, required=True, location="json"
-)
-
-
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
- @console_ns.expect(parser_uninstall)
+ @console_ns.expect(console_ns.models[ParserUninstall.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
- args = parser_uninstall.parse_args()
+ args = ParserUninstall.model_validate(console_ns.payload)
_, tenant_id = current_account_with_tenant()
try:
- return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
+ return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_change_post = (
- reqparse.RequestParser()
- .add_argument("install_permission", type=str, required=True, location="json")
- .add_argument("debug_permission", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
- @console_ns.expect(parser_change_post)
+ @console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -570,14 +644,15 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
- args = parser_change_post.parse_args()
-
- install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
- debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
+ args = ParserPermissionChange.model_validate(console_ns.payload)
tenant_id = current_tenant_id
- return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
+ return {
+ "success": PluginPermissionService.change_permission(
+ tenant_id, args.install_permission, args.debug_permission
+ )
+ }
@console_ns.route("/workspaces/current/plugin/permission/fetch")
@@ -605,20 +680,9 @@ class PluginFetchPermissionApi(Resource):
)
-parser_dynamic = (
- reqparse.RequestParser()
- .add_argument("plugin_id", type=str, required=True, location="args")
- .add_argument("provider", type=str, required=True, location="args")
- .add_argument("action", type=str, required=True, location="args")
- .add_argument("parameter", type=str, required=True, location="args")
- .add_argument("credential_id", type=str, required=False, location="args")
- .add_argument("provider_type", type=str, required=True, location="args")
-)
-
-
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
- @console_ns.expect(parser_dynamic)
+ @console_ns.expect(console_ns.models[ParserDynamicOptions.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -627,18 +691,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
- args = parser_dynamic.parse_args()
+ args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
- plugin_id=args["plugin_id"],
- provider=args["provider"],
- action=args["action"],
- parameter=args["parameter"],
- credential_id=args["credential_id"],
- provider_type=args["provider_type"],
+ plugin_id=args.plugin_id,
+ provider=args.provider,
+ action=args.action,
+ parameter=args.parameter,
+ credential_id=args.credential_id,
+ provider_type=args.provider_type,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -646,16 +710,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
-parser_change = (
- reqparse.RequestParser()
- .add_argument("permission", type=dict, required=True, location="json")
- .add_argument("auto_upgrade", type=dict, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
- @console_ns.expect(parser_change)
+ @console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -664,22 +721,20 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
- args = parser_change.parse_args()
+ args = ParserPreferencesChange.model_validate(console_ns.payload)
- permission = args["permission"]
+ permission = args.permission
- install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
- debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
+ install_permission = permission.install_permission
+ debug_permission = permission.debug_permission
- auto_upgrade = args["auto_upgrade"]
+ auto_upgrade = args.auto_upgrade
- strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
- auto_upgrade.get("strategy_setting", "fix_only")
- )
- upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
- upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
- exclude_plugins = auto_upgrade.get("exclude_plugins", [])
- include_plugins = auto_upgrade.get("include_plugins", [])
+ strategy_setting = auto_upgrade.strategy_setting
+ upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
+ upgrade_mode = auto_upgrade.upgrade_mode
+ exclude_plugins = auto_upgrade.exclude_plugins
+ include_plugins = auto_upgrade.include_plugins
# set permission
set_permission_result = PluginPermissionService.change_permission(
@@ -744,12 +799,9 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
-parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
-
-
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
- @console_ns.expect(parser_exclude)
+ @console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -757,28 +809,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
- args = parser_exclude.parse_args()
+ args = ParserExcludePlugin.model_validate(console_ns.payload)
- return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
+ return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
@console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource):
+ @console_ns.expect(console_ns.models[ParserReadme.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
- .add_argument("language", type=str, required=False, location="args")
- )
- args = parser.parse_args()
+ args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore
return jsonable_encoder(
- {
- "readme": PluginService.fetch_plugin_readme(
- tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
- )
- }
+ {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
)
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 37c7dc3040..9b76cb7a9c 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -1,7 +1,8 @@
import logging
from flask import request
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@@ -32,6 +33,45 @@ from services.file_service import FileService
from services.workspace_service import WorkspaceService
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkspaceListQuery(BaseModel):
+ page: int = Field(default=1, ge=1, le=99999)
+ limit: int = Field(default=20, ge=1, le=100)
+
+
+class SwitchWorkspacePayload(BaseModel):
+ tenant_id: str
+
+
+class WorkspaceCustomConfigPayload(BaseModel):
+ remove_webapp_brand: bool | None = None
+ replace_webapp_logo: str | None = None
+
+
+class WorkspaceInfoPayload(BaseModel):
+ name: str
+
+
+console_ns.schema_model(
+ WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+ SwitchWorkspacePayload.__name__,
+ SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ WorkspaceCustomConfigPayload.__name__,
+ WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+ WorkspaceInfoPayload.__name__,
+ WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
provider_fields = {
@@ -95,18 +135,15 @@ class TenantListApi(Resource):
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
+ @console_ns.expect(console_ns.models[WorkspaceListQuery.__name__])
@setup_required
@admin_required
def get(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
- .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
- )
- args = parser.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = WorkspaceListQuery.model_validate(payload)
stmt = select(Tenant).order_by(Tenant.created_at.desc())
- tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
+ tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False)
has_more = False
if tenants.has_next:
@@ -115,8 +152,8 @@ class WorkspaceListApi(Resource):
return {
"data": marshal(tenants.items, workspace_fields),
"has_more": has_more,
- "limit": args["limit"],
- "page": args["page"],
+ "limit": args.limit,
+ "page": args.page,
"total": tenants.total,
}, 200
@@ -150,26 +187,24 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200
-parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
-
-
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
- @console_ns.expect(parser_switch)
+ @console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_switch.parse_args()
+ payload = console_ns.payload or {}
+ args = SwitchWorkspacePayload.model_validate(payload)
# check if tenant_id is valid, 403 if not
try:
- TenantService.switch_tenant(current_user, args["tenant_id"])
+ TenantService.switch_tenant(current_user, args.tenant_id)
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
- new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
+ new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")
@@ -178,24 +213,21 @@ class SwitchWorkspaceApi(Resource):
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
+ @console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
_, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("remove_webapp_brand", type=bool, location="json")
- .add_argument("replace_webapp_logo", type=str, location="json")
- )
- args = parser.parse_args()
+ payload = console_ns.payload or {}
+ args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
- "remove_webapp_brand": args["remove_webapp_brand"],
- "replace_webapp_logo": args["replace_webapp_logo"]
- if args["replace_webapp_logo"] is not None
+ "remove_webapp_brand": args.remove_webapp_brand,
+ "replace_webapp_logo": args.replace_webapp_logo
+ if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),
}
@@ -245,24 +277,22 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
-parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
-
-
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
- @console_ns.expect(parser_info)
+ @console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__])
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
- args = parser_info.parse_args()
+ payload = console_ns.payload or {}
+ args = WorkspaceInfoPayload.model_validate(payload)
if not current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_tenant_id)
- tenant.name = args["name"]
+ tenant.name = args.name
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 13c3993fb5..ac6d1bde77 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -1352,7 +1352,7 @@ class RegisterService:
@classmethod
def invite_new_member(
- cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
+ cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
) -> str:
if not inviter:
raise ValueError("Inviter is required")
From 1e23957657bd4cd8a550a372243880f0f129d24e Mon Sep 17 00:00:00 2001
From: XlKsyt
Date: Wed, 26 Nov 2025 22:45:20 +0800
Subject: [PATCH 008/431] fix(ops): add streaming metrics and LLM span for
agent-chat traces (#28320)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
.../advanced_chat/generate_task_pipeline.py | 90 +++++++++++++++++--
api/core/app/entities/task_entities.py | 3 +
.../easy_ui_based_generate_task_pipeline.py | 18 ++++
api/core/ops/tencent_trace/span_builder.py | 53 +++++++++++
api/core/ops/tencent_trace/tencent_trace.py | 6 +-
api/models/model.py | 8 ++
6 files changed, 171 insertions(+), 7 deletions(-)
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 01c377956b..c98bc1ffdd 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -62,7 +62,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.ops_trace_manager import TraceQueueManager
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -72,7 +73,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
-from models.workflow import Workflow
+from models.workflow import Workflow, WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
@@ -580,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session:
# Save message
- self._save_message(session=session, graph_runtime_state=resolved_state)
+ self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield workflow_finish_resp
elif event.stopped_by in (
@@ -590,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
- self._save_message(session=session)
+ self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@@ -599,6 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
+ trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@@ -616,7 +618,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
- self._save_message(session=session, graph_runtime_state=resolved_state)
+ self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@@ -770,7 +772,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
- def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
+ def _save_message(
+ self,
+ *,
+ session: Session,
+ graph_runtime_state: GraphRuntimeState | None = None,
+ trace_manager: TraceQueueManager | None = None,
+ ):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@@ -809,6 +817,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
+
+ # Extract model provider and model_id from workflow node executions for tracing
+ if message.workflow_run_id:
+ model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
+ if model_info:
+ message.model_provider = model_info.get("provider")
+ message.model_id = model_info.get("model")
+
message_files = [
MessageFile(
message_id=message.id,
@@ -826,6 +842,68 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
+ # Trigger MESSAGE_TRACE for tracing integrations
+ if trace_manager:
+ trace_manager.add_trace_task(
+ TraceTask(
+ TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
+ )
+ )
+
+ def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
+ """
+ Extract model provider and model_id from workflow node executions.
+ Returns dict with 'provider' and 'model' keys, or None if not found.
+ """
+ try:
+ # Query workflow node executions for LLM or Agent nodes
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
+ .where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
+ .order_by(WorkflowNodeExecutionModel.created_at.desc())
+ .limit(1)
+ )
+ node_execution = session.scalar(stmt)
+
+ if not node_execution:
+ return None
+
+ # Try to extract from execution_metadata for agent nodes
+ if node_execution.execution_metadata:
+ try:
+ metadata = json.loads(node_execution.execution_metadata)
+ agent_log = metadata.get("agent_log", [])
+ # Look for the first agent thought with provider info
+ for log_entry in agent_log:
+ entry_metadata = log_entry.get("metadata", {})
+ provider_str = entry_metadata.get("provider")
+ if provider_str:
+ # Parse format like "langgenius/deepseek/deepseek"
+ parts = provider_str.split("/")
+ if len(parts) >= 3:
+ return {"provider": parts[1], "model": parts[2]}
+ elif len(parts) == 2:
+ return {"provider": parts[0], "model": parts[1]}
+ except (json.JSONDecodeError, KeyError, AttributeError) as e:
+ logger.debug("Failed to parse execution_metadata: %s", e)
+
+ # Try to extract from process_data for llm nodes
+ if node_execution.process_data:
+ try:
+ process_data = json.loads(node_execution.process_data)
+ provider = process_data.get("model_provider")
+ model = process_data.get("model_name")
+ if provider and model:
+ return {"provider": provider, "model": model}
+ except (json.JSONDecodeError, KeyError) as e:
+ logger.debug("Failed to parse process_data: %s", e)
+
+ return None
+ except Exception as e:
+ logger.warning("Failed to extract model info from workflow: %s", e)
+ return None
+
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 79a5e657b3..7692128985 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
"""
llm_result: LLMResult
+ first_token_time: float | None = None
+ last_token_time: float | None = None
+ is_streaming_response: bool = False
class WorkflowTaskState(TaskState):
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index da2ebac3bd..c49db9aad1 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -332,6 +332,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
+ # Track streaming response times
+ if self._task_state.first_token_time is None:
+ self._task_state.first_token_time = time.perf_counter()
+ self._task_state.is_streaming_response = True
+ self._task_state.last_token_time = time.perf_counter()
+
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@@ -398,6 +404,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
+
+ # Add streaming metrics to usage if available
+ if self._task_state.is_streaming_response and self._task_state.first_token_time:
+ start_time = self.start_at
+ first_token_time = self._task_state.first_token_time
+ last_token_time = self._task_state.last_token_time or first_token_time
+ usage.time_to_first_token = round(first_token_time - start_time, 3)
+ usage.time_to_generate = round(last_token_time - first_token_time, 3)
+
+ # Update metadata with the complete usage info
+ self._task_state.metadata.usage = usage
+
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py
index 26e8779e3e..db92e9b8bd 100644
--- a/api/core/ops/tencent_trace/span_builder.py
+++ b/api/core/ops/tencent_trace/span_builder.py
@@ -222,6 +222,59 @@ class TencentSpanBuilder:
links=links,
)
+ @staticmethod
+ def build_message_llm_span(
+ trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
+ ) -> SpanData:
+ """Build LLM span for message traces with detailed LLM attributes."""
+ status = Status(StatusCode.OK)
+ if trace_info.error:
+ status = Status(StatusCode.ERROR, trace_info.error)
+
+ # Extract model information from `metadata`` or `message_data`
+ trace_metadata = trace_info.metadata or {}
+ message_data = trace_info.message_data or {}
+
+ model_provider = trace_metadata.get("ls_provider") or (
+ message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
+ )
+ model_name = trace_metadata.get("ls_model_name") or (
+ message_data.get("model_id", "") if isinstance(message_data, dict) else ""
+ )
+
+ inputs_str = str(trace_info.inputs or "")
+ outputs_str = str(trace_info.outputs or "")
+
+ attributes = {
+ GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
+ GEN_AI_USER_ID: str(user_id),
+ GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
+ GEN_AI_FRAMEWORK: "dify",
+ GEN_AI_MODEL_NAME: str(model_name),
+ GEN_AI_PROVIDER: str(model_provider),
+ GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
+ GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
+ GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
+ GEN_AI_PROMPT: inputs_str,
+ GEN_AI_COMPLETION: outputs_str,
+ INPUT_VALUE: inputs_str,
+ OUTPUT_VALUE: outputs_str,
+ }
+
+ if trace_info.is_streaming_request:
+ attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
+
+ return SpanData(
+ trace_id=trace_id,
+ parent_span_id=parent_span_id,
+ span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
+ name="GENERATION",
+ start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
+ end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
+ attributes=attributes,
+ status=status,
+ )
+
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""
diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py
index 9b3df86e16..3d176da97a 100644
--- a/api/core/ops/tencent_trace/tencent_trace.py
+++ b/api/core/ops/tencent_trace/tencent_trace.py
@@ -107,9 +107,13 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
-
self.trace_client.add_span(message_span)
+ # Add LLM child span with detailed attributes
+ parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
+ llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
+ self.trace_client.add_span(llm_span)
+
self._record_message_llm_metrics(trace_info)
# Record trace duration for entry span
diff --git a/api/models/model.py b/api/models/model.py
index fb084d1dc6..33a94628f0 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -1251,9 +1251,13 @@ class Message(Base):
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
+ "model_provider": self.model_provider,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
+ "message_tokens": self.message_tokens,
+ "answer_tokens": self.answer_tokens,
+ "provider_response_latency": self.provider_response_latency,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,
@@ -1275,8 +1279,12 @@ class Message(Base):
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
+ model_provider=data.get("model_provider"),
model_id=data["model_id"],
inputs=data["inputs"],
+ message_tokens=data.get("message_tokens", 0),
+ answer_tokens=data.get("answer_tokens", 0),
+ provider_response_latency=data.get("provider_response_latency", 0.0),
total_price=data["total_price"],
query=data["query"],
message=data["message"],
From ddc5cbe86592cd1d1bcbb7cb2072fbb837e2331b Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 09:48:08 -0500
Subject: [PATCH 009/431] feat: complete test script of dataset service
(#28710)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../services/test_dataset_service.py | 1200 +++++++++++++++++
1 file changed, 1200 insertions(+)
create mode 100644 api/tests/unit_tests/services/test_dataset_service.py
diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py
new file mode 100644
index 0000000000..87fd29bbc0
--- /dev/null
+++ b/api/tests/unit_tests/services/test_dataset_service.py
@@ -0,0 +1,1200 @@
+"""
+Comprehensive unit tests for DatasetService.
+
+This test suite provides complete coverage of dataset management operations in Dify,
+following TDD principles with the Arrange-Act-Assert pattern.
+
+## Test Coverage
+
+### 1. Dataset Creation (TestDatasetServiceCreateDataset)
+Tests the creation of knowledge base datasets with various configurations:
+- Internal datasets (provider='vendor') with economy or high-quality indexing
+- External datasets (provider='external') connected to third-party APIs
+- Embedding model configuration for semantic search
+- Duplicate name validation
+- Permission and access control setup
+
+### 2. Dataset Updates (TestDatasetServiceUpdateDataset)
+Tests modification of existing dataset settings:
+- Basic field updates (name, description, permission)
+- Indexing technique switching (economy ↔ high_quality)
+- Embedding model changes with vector index rebuilding
+- Retrieval configuration updates
+- External knowledge binding updates
+
+### 3. Dataset Deletion (TestDatasetServiceDeleteDataset)
+Tests safe deletion with cascade cleanup:
+- Normal deletion with documents and embeddings
+- Empty dataset deletion (regression test for #27073)
+- Permission verification
+- Event-driven cleanup (vector DB, file storage)
+
+### 4. Document Indexing (TestDatasetServiceDocumentIndexing)
+Tests async document processing operations:
+- Pause/resume indexing for resource management
+- Retry failed documents
+- Status transitions through indexing pipeline
+- Redis-based concurrency control
+
+### 5. Retrieval Configuration (TestDatasetServiceRetrievalConfiguration)
+Tests search and ranking settings:
+- Search method configuration (semantic, full-text, hybrid)
+- Top-k and score threshold tuning
+- Reranking model integration for improved relevance
+
+## Testing Approach
+
+- **Mocking Strategy**: All external dependencies (database, Redis, model providers)
+ are mocked to ensure fast, isolated unit tests
+- **Factory Pattern**: DatasetServiceTestDataFactory provides consistent test data
+- **Fixtures**: Pytest fixtures set up common mock configurations per test class
+- **Assertions**: Each test verifies both the return value and all side effects
+ (database operations, event signals, async task triggers)
+
+## Key Concepts
+
+**Indexing Techniques:**
+- economy: Keyword-based search (fast, less accurate)
+- high_quality: Vector embeddings for semantic search (slower, more accurate)
+
+**Dataset Providers:**
+- vendor: Internal storage and indexing
+- external: Third-party knowledge sources via API
+
+**Document Lifecycle:**
+waiting → parsing → cleaning → splitting → indexing → completed (or error)
+"""
+
+from unittest.mock import Mock, create_autospec, patch
+from uuid import uuid4
+
+import pytest
+
+from core.model_runtime.entities.model_entities import ModelType
+from models.account import Account, TenantAccountRole
+from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings
+from services.dataset_service import DatasetService
+from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
+from services.errors.dataset import DatasetNameDuplicateError
+
+
+class DatasetServiceTestDataFactory:
+ """
+ Factory class for creating test data and mock objects.
+
+ This factory provides reusable methods to create mock objects for testing.
+ Using a factory pattern ensures consistency across tests and reduces code duplication.
+ All methods return properly configured Mock objects that simulate real model instances.
+ """
+
+ @staticmethod
+ def create_account_mock(
+ account_id: str = "account-123",
+ tenant_id: str = "tenant-123",
+ role: TenantAccountRole = TenantAccountRole.NORMAL,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock account with specified attributes.
+
+ Args:
+ account_id: Unique identifier for the account
+ tenant_id: Tenant ID the account belongs to
+ role: User role (NORMAL, ADMIN, etc.)
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock: A properly configured Account mock object
+ """
+ account = create_autospec(Account, instance=True)
+ account.id = account_id
+ account.current_tenant_id = tenant_id
+ account.current_role = role
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ name: str = "Test Dataset",
+ tenant_id: str = "tenant-123",
+ created_by: str = "user-123",
+ provider: str = "vendor",
+ indexing_technique: str | None = "high_quality",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock dataset with specified attributes.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ name: Display name of the dataset
+ tenant_id: Tenant ID the dataset belongs to
+ created_by: User ID who created the dataset
+ provider: Dataset provider type ('vendor' for internal, 'external' for external)
+ indexing_technique: Indexing method ('high_quality', 'economy', or None)
+ **kwargs: Additional attributes (embedding_model, retrieval_model, etc.)
+
+ Returns:
+ Mock: A properly configured Dataset mock object
+ """
+ dataset = create_autospec(Dataset, instance=True)
+ dataset.id = dataset_id
+ dataset.name = name
+ dataset.tenant_id = tenant_id
+ dataset.created_by = created_by
+ dataset.provider = provider
+ dataset.indexing_technique = indexing_technique
+ dataset.permission = kwargs.get("permission", DatasetPermissionEnum.ONLY_ME)
+ dataset.embedding_model_provider = kwargs.get("embedding_model_provider")
+ dataset.embedding_model = kwargs.get("embedding_model")
+ dataset.collection_binding_id = kwargs.get("collection_binding_id")
+ dataset.retrieval_model = kwargs.get("retrieval_model")
+ dataset.description = kwargs.get("description")
+ dataset.doc_form = kwargs.get("doc_form")
+ for key, value in kwargs.items():
+ if not hasattr(dataset, key):
+ setattr(dataset, key, value)
+ return dataset
+
+ @staticmethod
+ def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
+ """
+ Create a mock embedding model for high-quality indexing.
+
+ Embedding models are used to convert text into vector representations
+ for semantic search capabilities.
+
+ Args:
+ model: Model name (e.g., 'text-embedding-ada-002')
+ provider: Model provider (e.g., 'openai', 'cohere')
+
+ Returns:
+ Mock: Embedding model mock with model and provider attributes
+ """
+ embedding_model = Mock()
+ embedding_model.model = model
+ embedding_model.provider = provider
+ return embedding_model
+
+ @staticmethod
+ def create_retrieval_model_mock() -> Mock:
+ """
+ Create a mock retrieval model configuration.
+
+ Retrieval models define how documents are searched and ranked,
+ including search method, top-k results, and score thresholds.
+
+ Returns:
+ Mock: RetrievalModel mock with model_dump() method
+ """
+ retrieval_model = Mock(spec=RetrievalModel)
+ retrieval_model.model_dump.return_value = {
+ "search_method": "semantic_search",
+ "top_k": 2,
+ "score_threshold": 0.0,
+ }
+ retrieval_model.reranking_model = None
+ return retrieval_model
+
+ @staticmethod
+ def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
+ """
+ Create a mock collection binding for vector database.
+
+ Collection bindings link datasets to their vector storage locations
+ in the vector database (e.g., Qdrant, Weaviate).
+
+ Args:
+ binding_id: Unique identifier for the collection binding
+
+ Returns:
+ Mock: Collection binding mock object
+ """
+ binding = Mock()
+ binding.id = binding_id
+ return binding
+
+ @staticmethod
+ def create_external_binding_mock(
+ dataset_id: str = "dataset-123",
+ external_knowledge_id: str = "knowledge-123",
+ external_knowledge_api_id: str = "api-123",
+ ) -> Mock:
+ """
+ Create a mock external knowledge binding.
+
+ External knowledge bindings connect datasets to external knowledge sources
+ (e.g., third-party APIs, external databases) for retrieval.
+
+ Args:
+ dataset_id: Dataset ID this binding belongs to
+ external_knowledge_id: External knowledge source identifier
+ external_knowledge_api_id: External API configuration identifier
+
+ Returns:
+ Mock: ExternalKnowledgeBindings mock object
+ """
+ binding = Mock(spec=ExternalKnowledgeBindings)
+ binding.dataset_id = dataset_id
+ binding.external_knowledge_id = external_knowledge_id
+ binding.external_knowledge_api_id = external_knowledge_api_id
+ return binding
+
+ @staticmethod
+ def create_document_mock(
+ document_id: str = "doc-123",
+ dataset_id: str = "dataset-123",
+ indexing_status: str = "completed",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock document for testing document operations.
+
+ Documents are the individual files/content items within a dataset
+ that go through indexing, parsing, and chunking processes.
+
+ Args:
+ document_id: Unique identifier for the document
+ dataset_id: Parent dataset ID
+ indexing_status: Current status ('waiting', 'indexing', 'completed', 'error')
+ **kwargs: Additional attributes (is_paused, enabled, archived, etc.)
+
+ Returns:
+ Mock: Document mock object
+ """
+ document = Mock(spec=Document)
+ document.id = document_id
+ document.dataset_id = dataset_id
+ document.indexing_status = indexing_status
+ for key, value in kwargs.items():
+ setattr(document, key, value)
+ return document
+
+
+# ==================== Dataset Creation Tests ====================
+
+
+class TestDatasetServiceCreateDataset:
+ """
+ Comprehensive unit tests for dataset creation logic.
+
+ Covers:
+ - Internal dataset creation with various indexing techniques
+ - External dataset creation with external knowledge bindings
+ - RAG pipeline dataset creation
+ - Error handling for duplicate names and missing configurations
+ """
+
+ @pytest.fixture
+ def mock_dataset_service_dependencies(self):
+ """
+ Common mock setup for dataset service dependencies.
+
+ This fixture patches all external dependencies that DatasetService.create_empty_dataset
+ interacts with, including:
+ - db.session: Database operations (query, add, commit)
+ - ModelManager: Embedding model management
+ - check_embedding_model_setting: Validates embedding model configuration
+ - check_reranking_model_setting: Validates reranking model configuration
+ - ExternalDatasetService: Handles external knowledge API operations
+
+ Yields:
+ dict: Dictionary of mocked dependencies for use in tests
+ """
+ with (
+ patch("services.dataset_service.db.session") as mock_db,
+ patch("services.dataset_service.ModelManager") as mock_model_manager,
+ patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding,
+ patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
+ patch("services.dataset_service.ExternalDatasetService") as mock_external_service,
+ ):
+ yield {
+ "db_session": mock_db,
+ "model_manager": mock_model_manager,
+ "check_embedding": mock_check_embedding,
+ "check_reranking": mock_check_reranking,
+ "external_service": mock_external_service,
+ }
+
+ def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
+ """
+ Test successful creation of basic internal dataset.
+
+ Verifies that a dataset can be created with minimal configuration:
+ - No indexing technique specified (None)
+ - Default permission (only_me)
+ - Vendor provider (internal dataset)
+
+ This is the simplest dataset creation scenario.
+ """
+ # Arrange: Set up test data and mocks
+ tenant_id = str(uuid4())
+ account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id)
+ name = "Test Dataset"
+ description = "Test description"
+
+ # Mock database query to return None (no duplicate name exists)
+ mock_query = Mock()
+ mock_query.filter_by.return_value.first.return_value = None
+ mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
+
+ # Mock database session operations for dataset creation
+ mock_db = mock_dataset_service_dependencies["db_session"]
+ mock_db.add = Mock() # Tracks dataset being added to session
+ mock_db.flush = Mock() # Flushes to get dataset ID
+ mock_db.commit = Mock() # Commits transaction
+
+ # Act
+ result = DatasetService.create_empty_dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description=description,
+ indexing_technique=None,
+ account=account,
+ )
+
+ # Assert
+ assert result is not None
+ assert result.name == name
+ assert result.description == description
+ assert result.tenant_id == tenant_id
+ assert result.created_by == account.id
+ assert result.updated_by == account.id
+ assert result.provider == "vendor"
+ assert result.permission == "only_me"
+ mock_db.add.assert_called_once()
+ mock_db.commit.assert_called_once()
+
+ def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies):
+ """Test successful creation of internal dataset with economy indexing."""
+ # Arrange
+ tenant_id = str(uuid4())
+ account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id)
+ name = "Economy Dataset"
+
+ # Mock database query
+ mock_query = Mock()
+ mock_query.filter_by.return_value.first.return_value = None
+ mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
+
+ mock_db = mock_dataset_service_dependencies["db_session"]
+ mock_db.add = Mock()
+ mock_db.flush = Mock()
+ mock_db.commit = Mock()
+
+ # Act
+ result = DatasetService.create_empty_dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description=None,
+ indexing_technique="economy",
+ account=account,
+ )
+
+ # Assert
+ assert result.indexing_technique == "economy"
+ assert result.embedding_model_provider is None
+ assert result.embedding_model is None
+ mock_db.commit.assert_called_once()
+
+ def test_create_internal_dataset_with_high_quality_indexing(self, mock_dataset_service_dependencies):
+ """Test creation with high_quality indexing using default embedding model."""
+ # Arrange
+ tenant_id = str(uuid4())
+ account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id)
+ name = "High Quality Dataset"
+
+ # Mock database query
+ mock_query = Mock()
+ mock_query.filter_by.return_value.first.return_value = None
+ mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
+
+ # Mock model manager
+ embedding_model = DatasetServiceTestDataFactory.create_embedding_model_mock()
+ mock_model_manager_instance = Mock()
+ mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
+ mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
+
+ mock_db = mock_dataset_service_dependencies["db_session"]
+ mock_db.add = Mock()
+ mock_db.flush = Mock()
+ mock_db.commit = Mock()
+
+ # Act
+ result = DatasetService.create_empty_dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description=None,
+ indexing_technique="high_quality",
+ account=account,
+ )
+
+ # Assert
+ assert result.indexing_technique == "high_quality"
+ assert result.embedding_model_provider == embedding_model.provider
+ assert result.embedding_model == embedding_model.model
+ mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
+ tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
+ )
+ mock_db.commit.assert_called_once()
+
+ def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
+ """Test error when creating dataset with duplicate name."""
+ # Arrange
+ tenant_id = str(uuid4())
+ account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id)
+ name = "Duplicate Dataset"
+
+ # Mock database query to return existing dataset
+ existing_dataset = DatasetServiceTestDataFactory.create_dataset_mock(name=name, tenant_id=tenant_id)
+ mock_query = Mock()
+ mock_query.filter_by.return_value.first.return_value = existing_dataset
+ mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(DatasetNameDuplicateError) as context:
+ DatasetService.create_empty_dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description=None,
+ indexing_technique=None,
+ account=account,
+ )
+
+ assert f"Dataset with name {name} already exists" in str(context.value)
+
+ def test_create_external_dataset_success(self, mock_dataset_service_dependencies):
+ """Test successful creation of external dataset with external knowledge binding."""
+ # Arrange
+ tenant_id = str(uuid4())
+ account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id)
+ name = "External Dataset"
+ external_knowledge_api_id = "api-123"
+ external_knowledge_id = "knowledge-123"
+
+ # Mock database query
+ mock_query = Mock()
+ mock_query.filter_by.return_value.first.return_value = None
+ mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
+
+ # Mock external knowledge API
+ external_api = Mock()
+ external_api.id = external_knowledge_api_id
+ mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
+
+ mock_db = mock_dataset_service_dependencies["db_session"]
+ mock_db.add = Mock()
+ mock_db.flush = Mock()
+ mock_db.commit = Mock()
+
+ # Act
+ result = DatasetService.create_empty_dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description=None,
+ indexing_technique=None,
+ account=account,
+ provider="external",
+ external_knowledge_api_id=external_knowledge_api_id,
+ external_knowledge_id=external_knowledge_id,
+ )
+
+ # Assert
+ assert result.provider == "external"
+ assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBinding
+ mock_db.commit.assert_called_once()
+
+
+# ==================== Dataset Update Tests ====================
+
+
+class TestDatasetServiceUpdateDataset:
+ """
+ Comprehensive unit tests for dataset update settings.
+
+ Covers:
+ - Basic field updates (name, description, permission)
+ - Indexing technique changes (economy <-> high_quality)
+ - Embedding model updates
+ - Retrieval configuration updates
+ - External dataset updates
+ """
+
+ @pytest.fixture
+ def mock_dataset_service_dependencies(self):
+ """Common mock setup for dataset service dependencies."""
+ with (
+ patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
+ patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name,
+ patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
+ patch("services.dataset_service.db.session") as mock_db,
+ patch("services.dataset_service.naive_utc_now") as mock_time,
+ patch(
+ "services.dataset_service.DatasetService._update_pipeline_knowledge_base_node_data"
+ ) as mock_update_pipeline,
+ ):
+ mock_time.return_value = "2024-01-01T00:00:00"
+ yield {
+ "get_dataset": mock_get_dataset,
+ "has_dataset_same_name": mock_has_same_name,
+ "check_permission": mock_check_perm,
+ "db_session": mock_db,
+ "current_time": "2024-01-01T00:00:00",
+ "update_pipeline": mock_update_pipeline,
+ }
+
+ @pytest.fixture
+ def mock_internal_provider_dependencies(self):
+ """Mock dependencies for internal dataset provider operations."""
+ with (
+ patch("services.dataset_service.ModelManager") as mock_model_manager,
+ patch("services.dataset_service.DatasetCollectionBindingService") as mock_binding_service,
+ patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
+ patch("services.dataset_service.current_user") as mock_current_user,
+ ):
+ # Mock current_user as Account instance
+ mock_current_user_account = DatasetServiceTestDataFactory.create_account_mock(
+ account_id="user-123", tenant_id="tenant-123"
+ )
+ mock_current_user.return_value = mock_current_user_account
+ mock_current_user.current_tenant_id = "tenant-123"
+ mock_current_user.id = "user-123"
+ # Make isinstance check pass
+ mock_current_user.__class__ = Account
+
+ yield {
+ "model_manager": mock_model_manager,
+ "get_binding": mock_binding_service.get_dataset_collection_binding,
+ "task": mock_task,
+ "current_user": mock_current_user,
+ }
+
+ @pytest.fixture
+ def mock_external_provider_dependencies(self):
+ """Mock dependencies for external dataset provider operations."""
+ with (
+ patch("services.dataset_service.Session") as mock_session,
+ patch("services.dataset_service.db.engine") as mock_engine,
+ ):
+ yield mock_session
+
+ def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
+ """Test successful update of internal dataset with basic fields."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(
+ provider="vendor",
+ indexing_technique="high_quality",
+ embedding_model_provider="openai",
+ embedding_model="text-embedding-ada-002",
+ collection_binding_id="binding-123",
+ )
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ update_data = {
+ "name": "new_name",
+ "description": "new_description",
+ "indexing_technique": "high_quality",
+ "retrieval_model": "new_model",
+ "embedding_model_provider": "openai",
+ "embedding_model": "text-embedding-ada-002",
+ }
+
+ mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
+ # Act
+ result = DatasetService.update_dataset("dataset-123", update_data, user)
+
+ # Assert
+ mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+ mock_dataset_service_dependencies[
+ "db_session"
+ ].query.return_value.filter_by.return_value.update.assert_called_once()
+ mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+ assert result == dataset
+
+ def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
+ """Test error when updating non-existent dataset."""
+ # Arrange
+ mock_dataset_service_dependencies["get_dataset"].return_value = None
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ # Act & Assert
+ with pytest.raises(ValueError) as context:
+ DatasetService.update_dataset("non-existent", {}, user)
+
+ assert "Dataset not found" in str(context.value)
+
+ def test_update_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
+ """Test error when updating dataset to duplicate name."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock()
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+ mock_dataset_service_dependencies["has_dataset_same_name"].return_value = True
+
+ user = DatasetServiceTestDataFactory.create_account_mock()
+ update_data = {"name": "duplicate_name"}
+
+ # Act & Assert
+ with pytest.raises(ValueError) as context:
+ DatasetService.update_dataset("dataset-123", update_data, user)
+
+ assert "Dataset name already exists" in str(context.value)
+
+ def test_update_indexing_technique_to_economy(
+ self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
+ ):
+ """Test updating indexing technique from high_quality to economy."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(
+ provider="vendor", indexing_technique="high_quality"
+ )
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
+ mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
+ # Act
+ result = DatasetService.update_dataset("dataset-123", update_data, user)
+
+ # Assert
+ mock_dataset_service_dependencies[
+ "db_session"
+ ].query.return_value.filter_by.return_value.update.assert_called_once()
+ # Verify embedding model fields are cleared
+ call_args = mock_dataset_service_dependencies[
+ "db_session"
+ ].query.return_value.filter_by.return_value.update.call_args[0][0]
+ assert call_args["embedding_model"] is None
+ assert call_args["embedding_model_provider"] is None
+ assert call_args["collection_binding_id"] is None
+ assert result == dataset
+
+ def test_update_indexing_technique_to_high_quality(
+ self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
+ ):
+ """Test updating indexing technique from economy to high_quality."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ # Mock embedding model
+ embedding_model = DatasetServiceTestDataFactory.create_embedding_model_mock()
+ mock_internal_provider_dependencies[
+ "model_manager"
+ ].return_value.get_model_instance.return_value = embedding_model
+
+ # Mock collection binding
+ binding = DatasetServiceTestDataFactory.create_collection_binding_mock()
+ mock_internal_provider_dependencies["get_binding"].return_value = binding
+
+ update_data = {
+ "indexing_technique": "high_quality",
+ "embedding_model_provider": "openai",
+ "embedding_model": "text-embedding-ada-002",
+ "retrieval_model": "new_model",
+ }
+ mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
+ # Act
+ result = DatasetService.update_dataset("dataset-123", update_data, user)
+
+ # Assert
+ mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once()
+ mock_internal_provider_dependencies["get_binding"].assert_called_once()
+ mock_internal_provider_dependencies["task"].delay.assert_called_once()
+ call_args = mock_internal_provider_dependencies["task"].delay.call_args[0]
+ assert call_args[0] == "dataset-123"
+ assert call_args[1] == "add"
+
+ # Verify return value
+ assert result == dataset
+
+ # Note: External dataset update test removed due to Flask app context complexity in unit tests
+ # External dataset functionality is covered by integration tests
+
+ def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
+ """Test error when external knowledge id is missing."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(provider="external")
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ user = DatasetServiceTestDataFactory.create_account_mock()
+ update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
+ mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
+ # Act & Assert
+ with pytest.raises(ValueError) as context:
+ DatasetService.update_dataset("dataset-123", update_data, user)
+
+ assert "External knowledge id is required" in str(context.value)
+
+
+# ==================== Dataset Deletion Tests ====================
+
+
+class TestDatasetServiceDeleteDataset:
+ """
+ Comprehensive unit tests for dataset deletion with cascade operations.
+
+ Covers:
+ - Normal dataset deletion with documents
+ - Empty dataset deletion (no documents)
+ - Dataset deletion with partial None values
+ - Permission checks
+ - Event handling for cascade operations
+
+ Dataset deletion is a critical operation that triggers cascade cleanup:
+ - Documents and segments are removed from vector database
+ - File storage is cleaned up
+ - Related bindings and metadata are deleted
+ - The dataset_was_deleted event notifies listeners for cleanup
+ """
+
+ @pytest.fixture
+ def mock_dataset_service_dependencies(self):
+ """
+ Common mock setup for dataset deletion dependencies.
+
+ Patches:
+ - get_dataset: Retrieves the dataset to delete
+ - check_dataset_permission: Verifies user has delete permission
+ - db.session: Database operations (delete, commit)
+ - dataset_was_deleted: Signal/event for cascade cleanup operations
+
+ The dataset_was_deleted signal is crucial - it triggers cleanup handlers
+ that remove vector embeddings, files, and related data.
+ """
+ with (
+ patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
+ patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
+ patch("services.dataset_service.db.session") as mock_db,
+ patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted,
+ ):
+ yield {
+ "get_dataset": mock_get_dataset,
+ "check_permission": mock_check_perm,
+ "db_session": mock_db,
+ "dataset_was_deleted": mock_dataset_was_deleted,
+ }
+
+ def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies):
+ """Test successful deletion of a dataset with documents."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(
+ doc_form="text_model", indexing_technique="high_quality"
+ )
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ # Act
+ result = DatasetService.delete_dataset(dataset.id, user)
+
+ # Assert
+ assert result is True
+ mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
+ mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+ mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+ mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+ mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+ def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies):
+ """
+ Test successful deletion of an empty dataset (no documents, doc_form is None).
+
+ Empty datasets are created but never had documents uploaded. They have:
+ - doc_form = None (no document format configured)
+ - indexing_technique = None (no indexing method set)
+
+ This test ensures empty datasets can be deleted without errors.
+ The event handler should gracefully skip cleanup operations when
+ there's no actual data to clean up.
+
+ This test provides regression protection for issue #27073 where
+ deleting empty datasets caused internal server errors.
+ """
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None)
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ # Act
+ result = DatasetService.delete_dataset(dataset.id, user)
+
+ # Assert - Verify complete deletion flow
+ assert result is True
+ mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
+ mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
+ # Event is sent even for empty datasets - handlers check for None values
+ mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+ mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+ mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+ def test_delete_dataset_not_found(self, mock_dataset_service_dependencies):
+ """Test deletion attempt when dataset doesn't exist."""
+ # Arrange
+ dataset_id = "non-existent-dataset"
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ mock_dataset_service_dependencies["get_dataset"].return_value = None
+
+ # Act
+ result = DatasetService.delete_dataset(dataset_id, user)
+
+ # Assert
+ assert result is False
+ mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
+ mock_dataset_service_dependencies["check_permission"].assert_not_called()
+ mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called()
+ mock_dataset_service_dependencies["db_session"].delete.assert_not_called()
+ mock_dataset_service_dependencies["db_session"].commit.assert_not_called()
+
+ def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies):
+ """Test deletion of dataset with partial None values (doc_form exists but indexing_technique is None)."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None)
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ # Act
+ result = DatasetService.delete_dataset(dataset.id, user)
+
+ # Assert
+ assert result is True
+ mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
+ mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
+ mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
+
+
+# ==================== Document Indexing Logic Tests ====================
+
+
+class TestDatasetServiceDocumentIndexing:
+ """
+ Comprehensive unit tests for document indexing logic.
+
+ Covers:
+ - Document indexing status transitions
+ - Pause/resume document indexing
+ - Retry document indexing
+ - Sync website document indexing
+ - Document indexing task triggering
+
+ Document indexing is an async process with multiple stages:
+ 1. waiting: Document queued for processing
+ 2. parsing: Extracting text from file
+ 3. cleaning: Removing unwanted content
+ 4. splitting: Breaking into chunks
+ 5. indexing: Creating embeddings and storing in vector DB
+ 6. completed: Successfully indexed
+ 7. error: Failed at some stage
+
+ Users can pause/resume indexing or retry failed documents.
+ """
+
+ @pytest.fixture
+ def mock_document_service_dependencies(self):
+ """
+ Common mock setup for document service dependencies.
+
+ Patches:
+ - redis_client: Caches indexing state and prevents concurrent operations
+ - db.session: Database operations for document status updates
+ - current_user: User context for tracking who paused/resumed
+
+ Redis is used to:
+ - Store pause flags (document_{id}_is_paused)
+ - Prevent duplicate retry operations (document_{id}_is_retried)
+ - Track active indexing operations (document_{id}_indexing)
+ """
+ with (
+ patch("services.dataset_service.redis_client") as mock_redis,
+ patch("services.dataset_service.db.session") as mock_db,
+ patch("services.dataset_service.current_user") as mock_current_user,
+ ):
+ mock_current_user.id = "user-123"
+ yield {
+ "redis_client": mock_redis,
+ "db_session": mock_db,
+ "current_user": mock_current_user,
+ }
+
+ def test_pause_document_success(self, mock_document_service_dependencies):
+ """
+ Test successful pause of document indexing.
+
+ Pausing allows users to temporarily stop indexing without canceling it.
+ This is useful when:
+ - System resources are needed elsewhere
+ - User wants to modify document settings before continuing
+ - Indexing is taking too long and needs to be deferred
+
+ When paused:
+ - is_paused flag is set to True
+ - paused_by and paused_at are recorded
+ - Redis flag prevents indexing worker from processing
+ - Document remains in current indexing stage
+ """
+ # Arrange
+ document = DatasetServiceTestDataFactory.create_document_mock(indexing_status="indexing")
+ mock_db = mock_document_service_dependencies["db_session"]
+ mock_redis = mock_document_service_dependencies["redis_client"]
+
+ # Act
+ from services.dataset_service import DocumentService
+
+ DocumentService.pause_document(document)
+
+ # Assert - Verify pause state is persisted
+ assert document.is_paused is True
+ mock_db.add.assert_called_once_with(document)
+ mock_db.commit.assert_called_once()
+ # setnx (set if not exists) prevents race conditions
+ mock_redis.setnx.assert_called_once()
+
+ def test_pause_document_invalid_status_error(self, mock_document_service_dependencies):
+ """Test error when pausing document with invalid status."""
+ # Arrange
+ document = DatasetServiceTestDataFactory.create_document_mock(indexing_status="completed")
+
+ # Act & Assert
+ from services.dataset_service import DocumentService
+ from services.errors.document import DocumentIndexingError
+
+ with pytest.raises(DocumentIndexingError):
+ DocumentService.pause_document(document)
+
+ def test_recover_document_success(self, mock_document_service_dependencies):
+ """Test successful recovery of paused document indexing."""
+ # Arrange
+ document = DatasetServiceTestDataFactory.create_document_mock(indexing_status="indexing", is_paused=True)
+ mock_db = mock_document_service_dependencies["db_session"]
+ mock_redis = mock_document_service_dependencies["redis_client"]
+
+ # Act
+ with patch("services.dataset_service.recover_document_indexing_task") as mock_task:
+ from services.dataset_service import DocumentService
+
+ DocumentService.recover_document(document)
+
+ # Assert
+ assert document.is_paused is False
+ mock_db.add.assert_called_once_with(document)
+ mock_db.commit.assert_called_once()
+ mock_redis.delete.assert_called_once()
+ mock_task.delay.assert_called_once_with(document.dataset_id, document.id)
+
+ def test_retry_document_indexing_success(self, mock_document_service_dependencies):
+ """Test successful retry of document indexing."""
+ # Arrange
+ dataset_id = "dataset-123"
+ documents = [
+ DatasetServiceTestDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"),
+ DatasetServiceTestDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"),
+ ]
+ mock_db = mock_document_service_dependencies["db_session"]
+ mock_redis = mock_document_service_dependencies["redis_client"]
+ mock_redis.get.return_value = None
+
+ # Act
+ with patch("services.dataset_service.retry_document_indexing_task") as mock_task:
+ from services.dataset_service import DocumentService
+
+ DocumentService.retry_document(dataset_id, documents)
+
+ # Assert
+ for doc in documents:
+ assert doc.indexing_status == "waiting"
+ assert mock_db.add.call_count == len(documents)
+ # Commit is called once per document
+ assert mock_db.commit.call_count == len(documents)
+ mock_task.delay.assert_called_once()
+
+
+# ==================== Retrieval Configuration Tests ====================
+
+
+class TestDatasetServiceRetrievalConfiguration:
+ """
+ Comprehensive unit tests for retrieval configuration.
+
+ Covers:
+ - Retrieval model configuration
+ - Search method configuration
+ - Top-k and score threshold settings
+ - Reranking model configuration
+
+ Retrieval configuration controls how documents are searched and ranked:
+
+ Search Methods:
+ - semantic_search: Uses vector similarity (cosine distance)
+ - full_text_search: Uses keyword matching (BM25)
+ - hybrid_search: Combines both methods with weighted scores
+
+ Parameters:
+ - top_k: Number of results to return (default: 2-10)
+ - score_threshold: Minimum similarity score (0.0-1.0)
+ - reranking_enable: Whether to use reranking model for better results
+
+ Reranking:
+ After initial retrieval, a reranking model (e.g., Cohere rerank) can
+ reorder results for better relevance. This is more accurate but slower.
+ """
+
+ @pytest.fixture
+ def mock_dataset_service_dependencies(self):
+ """
+ Common mock setup for retrieval configuration tests.
+
+ Patches:
+ - get_dataset: Retrieves dataset with retrieval configuration
+ - db.session: Database operations for configuration updates
+ """
+ with (
+ patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
+ patch("services.dataset_service.db.session") as mock_db,
+ ):
+ yield {
+ "get_dataset": mock_get_dataset,
+ "db_session": mock_db,
+ }
+
+ def test_get_dataset_retrieval_configuration(self, mock_dataset_service_dependencies):
+ """Test retrieving dataset with retrieval configuration."""
+ # Arrange
+ dataset_id = "dataset-123"
+ retrieval_model_config = {
+ "search_method": "semantic_search",
+ "top_k": 5,
+ "score_threshold": 0.5,
+ "reranking_enable": True,
+ }
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(
+ dataset_id=dataset_id, retrieval_model=retrieval_model_config
+ )
+
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+
+ # Act
+ result = DatasetService.get_dataset(dataset_id)
+
+ # Assert
+ assert result is not None
+ assert result.retrieval_model == retrieval_model_config
+ assert result.retrieval_model["search_method"] == "semantic_search"
+ assert result.retrieval_model["top_k"] == 5
+ assert result.retrieval_model["score_threshold"] == 0.5
+
+ def test_update_dataset_retrieval_configuration(self, mock_dataset_service_dependencies):
+ """Test updating dataset retrieval configuration."""
+ # Arrange
+ dataset = DatasetServiceTestDataFactory.create_dataset_mock(
+ provider="vendor",
+ indexing_technique="high_quality",
+ retrieval_model={"search_method": "semantic_search", "top_k": 2},
+ )
+
+ with (
+ patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name,
+ patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
+ patch("services.dataset_service.naive_utc_now") as mock_time,
+ patch(
+ "services.dataset_service.DatasetService._update_pipeline_knowledge_base_node_data"
+ ) as mock_update_pipeline,
+ ):
+ mock_dataset_service_dependencies["get_dataset"].return_value = dataset
+ mock_has_same_name.return_value = False
+ mock_time.return_value = "2024-01-01T00:00:00"
+
+ user = DatasetServiceTestDataFactory.create_account_mock()
+
+ new_retrieval_config = {
+ "search_method": "full_text_search",
+ "top_k": 10,
+ "score_threshold": 0.7,
+ }
+
+ update_data = {
+ "indexing_technique": "high_quality",
+ "retrieval_model": new_retrieval_config,
+ }
+
+ # Act
+ result = DatasetService.update_dataset("dataset-123", update_data, user)
+
+ # Assert
+ mock_dataset_service_dependencies[
+ "db_session"
+ ].query.return_value.filter_by.return_value.update.assert_called_once()
+ call_args = mock_dataset_service_dependencies[
+ "db_session"
+ ].query.return_value.filter_by.return_value.update.call_args[0][0]
+ assert call_args["retrieval_model"] == new_retrieval_config
+ assert result == dataset
+
+ def test_create_dataset_with_retrieval_model_and_reranking(self, mock_dataset_service_dependencies):
+ """Test creating dataset with retrieval model and reranking configuration."""
+ # Arrange
+ tenant_id = str(uuid4())
+ account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id)
+ name = "Dataset with Reranking"
+
+ # Mock database query
+ mock_query = Mock()
+ mock_query.filter_by.return_value.first.return_value = None
+ mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
+
+ # Mock retrieval model with reranking
+ retrieval_model = Mock(spec=RetrievalModel)
+ retrieval_model.model_dump.return_value = {
+ "search_method": "semantic_search",
+ "top_k": 3,
+ "score_threshold": 0.6,
+ "reranking_enable": True,
+ }
+ reranking_model = Mock()
+ reranking_model.reranking_provider_name = "cohere"
+ reranking_model.reranking_model_name = "rerank-english-v2.0"
+ retrieval_model.reranking_model = reranking_model
+
+ # Mock model manager
+ embedding_model = DatasetServiceTestDataFactory.create_embedding_model_mock()
+ mock_model_manager_instance = Mock()
+ mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
+
+ with (
+ patch("services.dataset_service.ModelManager") as mock_model_manager,
+ patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding,
+ patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
+ ):
+ mock_model_manager.return_value = mock_model_manager_instance
+
+ mock_db = mock_dataset_service_dependencies["db_session"]
+ mock_db.add = Mock()
+ mock_db.flush = Mock()
+ mock_db.commit = Mock()
+
+ # Act
+ result = DatasetService.create_empty_dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description=None,
+ indexing_technique="high_quality",
+ account=account,
+ retrieval_model=retrieval_model,
+ )
+
+ # Assert
+ assert result.retrieval_model == retrieval_model.model_dump()
+ mock_check_reranking.assert_called_once_with(tenant_id, "cohere", "rerank-english-v2.0")
+ mock_db.commit.assert_called_once()
From b2a7cec644e79c5c5e38f983d9466254414a7b5d Mon Sep 17 00:00:00 2001
From: Satoshi Dev <162055292+0xsatoshi99@users.noreply.github.com>
Date: Wed, 26 Nov 2025 06:50:20 -0800
Subject: [PATCH 010/431] add unit tests for template transform node (#28595)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../nodes/template_transform/__init__.py | 1 +
.../nodes/template_transform/entities_spec.py | 225 ++++++++++
.../template_transform_node_spec.py | 414 ++++++++++++++++++
3 files changed, 640 insertions(+)
create mode 100644 api/tests/unit_tests/core/workflow/nodes/template_transform/__init__.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/__init__.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/__init__.py
@@ -0,0 +1 @@
+
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
new file mode 100644
index 0000000000..5eb302798f
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
@@ -0,0 +1,225 @@
+import pytest
+from pydantic import ValidationError
+
+from core.workflow.enums import ErrorStrategy
+from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
+
+
+class TestTemplateTransformNodeData:
+ """Test suite for TemplateTransformNodeData entity."""
+
+ def test_valid_template_transform_node_data(self):
+ """Test creating valid TemplateTransformNodeData."""
+ data = {
+ "title": "Template Transform",
+ "desc": "Transform data using Jinja2 template",
+ "variables": [
+ {"variable": "name", "value_selector": ["sys", "user_name"]},
+ {"variable": "age", "value_selector": ["sys", "user_age"]},
+ ],
+ "template": "Hello {{ name }}, you are {{ age }} years old!",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.title == "Template Transform"
+ assert node_data.desc == "Transform data using Jinja2 template"
+ assert len(node_data.variables) == 2
+ assert node_data.variables[0].variable == "name"
+ assert node_data.variables[0].value_selector == ["sys", "user_name"]
+ assert node_data.variables[1].variable == "age"
+ assert node_data.variables[1].value_selector == ["sys", "user_age"]
+ assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!"
+
+ def test_template_transform_node_data_with_empty_variables(self):
+ """Test TemplateTransformNodeData with no variables."""
+ data = {
+ "title": "Static Template",
+ "variables": [],
+ "template": "This is a static template with no variables.",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.title == "Static Template"
+ assert len(node_data.variables) == 0
+ assert node_data.template == "This is a static template with no variables."
+
+ def test_template_transform_node_data_with_complex_template(self):
+ """Test TemplateTransformNodeData with complex Jinja2 template."""
+ data = {
+ "title": "Complex Template",
+ "variables": [
+ {"variable": "items", "value_selector": ["sys", "item_list"]},
+ {"variable": "total", "value_selector": ["sys", "total_count"]},
+ ],
+ "template": (
+ "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}"
+ ),
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.title == "Complex Template"
+ assert len(node_data.variables) == 2
+ assert "{% for item in items %}" in node_data.template
+ assert "{{ total }}" in node_data.template
+
+ def test_template_transform_node_data_with_error_strategy(self):
+ """Test TemplateTransformNodeData with error handling strategy."""
+ data = {
+ "title": "Template with Error Handling",
+ "variables": [{"variable": "value", "value_selector": ["sys", "input"]}],
+ "template": "{{ value }}",
+ "error_strategy": "fail-branch",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
+
+ def test_template_transform_node_data_with_retry_config(self):
+ """Test TemplateTransformNodeData with retry configuration."""
+ data = {
+ "title": "Template with Retry",
+ "variables": [{"variable": "data", "value_selector": ["sys", "data"]}],
+ "template": "{{ data }}",
+ "retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000},
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.retry_config.enabled is True
+ assert node_data.retry_config.max_retries == 3
+ assert node_data.retry_config.retry_interval == 1000
+
+ def test_template_transform_node_data_missing_required_fields(self):
+ """Test that missing required fields raises ValidationError."""
+ data = {
+ "title": "Incomplete Template",
+ # Missing 'variables' and 'template'
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ TemplateTransformNodeData.model_validate(data)
+
+ errors = exc_info.value.errors()
+ assert len(errors) >= 2
+ error_fields = {error["loc"][0] for error in errors}
+ assert "variables" in error_fields
+ assert "template" in error_fields
+
+ def test_template_transform_node_data_invalid_variable_selector(self):
+ """Test that invalid variable selector format raises ValidationError."""
+ data = {
+ "title": "Invalid Variable",
+ "variables": [
+ {"variable": "name", "value_selector": "invalid_format"} # Should be list
+ ],
+ "template": "{{ name }}",
+ }
+
+ with pytest.raises(ValidationError):
+ TemplateTransformNodeData.model_validate(data)
+
+ def test_template_transform_node_data_with_default_value_dict(self):
+ """Test TemplateTransformNodeData with default value dictionary."""
+ data = {
+ "title": "Template with Defaults",
+ "variables": [
+ {"variable": "name", "value_selector": ["sys", "user_name"]},
+ {"variable": "greeting", "value_selector": ["sys", "greeting"]},
+ ],
+ "template": "{{ greeting }} {{ name }}!",
+ "default_value_dict": {"greeting": "Hello", "name": "Guest"},
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"}
+
+ def test_template_transform_node_data_with_nested_selectors(self):
+ """Test TemplateTransformNodeData with nested variable selectors."""
+ data = {
+ "title": "Nested Selectors",
+ "variables": [
+ {"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]},
+ {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]},
+ ],
+ "template": "User: {{ user_info }}, Theme: {{ settings }}",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert len(node_data.variables) == 2
+ assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"]
+ assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"]
+
+ def test_template_transform_node_data_with_multiline_template(self):
+ """Test TemplateTransformNodeData with multiline template."""
+ data = {
+ "title": "Multiline Template",
+ "variables": [
+ {"variable": "title", "value_selector": ["sys", "title"]},
+ {"variable": "content", "value_selector": ["sys", "content"]},
+ ],
+ "template": """
+# {{ title }}
+
+{{ content }}
+
+---
+Generated by Template Transform Node
+ """,
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert "# {{ title }}" in node_data.template
+ assert "{{ content }}" in node_data.template
+ assert "Generated by Template Transform Node" in node_data.template
+
+ def test_template_transform_node_data_serialization(self):
+ """Test that TemplateTransformNodeData can be serialized and deserialized."""
+ original_data = {
+ "title": "Serialization Test",
+ "desc": "Test serialization",
+ "variables": [{"variable": "test", "value_selector": ["sys", "test"]}],
+ "template": "{{ test }}",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(original_data)
+ serialized = node_data.model_dump()
+ deserialized = TemplateTransformNodeData.model_validate(serialized)
+
+ assert deserialized.title == node_data.title
+ assert deserialized.desc == node_data.desc
+ assert len(deserialized.variables) == len(node_data.variables)
+ assert deserialized.template == node_data.template
+
+ def test_template_transform_node_data_with_special_characters(self):
+ """Test TemplateTransformNodeData with special characters in template."""
+ data = {
+ "title": "Special Characters",
+ "variables": [{"variable": "text", "value_selector": ["sys", "input"]}],
+ "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert "@#$%^&*()" in node_data.template
+ assert "你好" in node_data.template
+ assert "🎉" in node_data.template
+
+ def test_template_transform_node_data_empty_template(self):
+ """Test TemplateTransformNodeData with empty template string."""
+ data = {
+ "title": "Empty Template",
+ "variables": [],
+ "template": "",
+ }
+
+ node_data = TemplateTransformNodeData.model_validate(data)
+
+ assert node_data.template == ""
+ assert len(node_data.variables) == 0
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
new file mode 100644
index 0000000000..1a67d5c3e3
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -0,0 +1,414 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+
+from core.helper.code_executor.code_executor import CodeExecutionError
+from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from models.workflow import WorkflowType
+
+
+class TestTemplateTransformNode:
+ """Comprehensive test suite for TemplateTransformNode."""
+
+ @pytest.fixture
+ def mock_graph_runtime_state(self):
+ """Create a mock GraphRuntimeState with variable pool."""
+ mock_state = MagicMock(spec=GraphRuntimeState)
+ mock_variable_pool = MagicMock()
+ mock_state.variable_pool = mock_variable_pool
+ return mock_state
+
+ @pytest.fixture
+ def mock_graph(self):
+ """Create a mock Graph."""
+ return MagicMock(spec=Graph)
+
+ @pytest.fixture
+ def graph_init_params(self):
+ """Create a mock GraphInitParams."""
+ return GraphInitParams(
+ tenant_id="test_tenant",
+ app_id="test_app",
+ workflow_type=WorkflowType.WORKFLOW,
+ workflow_id="test_workflow",
+ graph_config={},
+ user_id="test_user",
+ user_from="test",
+ invoke_from="test",
+ call_depth=0,
+ )
+
+ @pytest.fixture
+ def basic_node_data(self):
+ """Create basic node data for testing."""
+ return {
+ "title": "Template Transform",
+ "desc": "Transform data using template",
+ "variables": [
+ {"variable": "name", "value_selector": ["sys", "user_name"]},
+ {"variable": "age", "value_selector": ["sys", "user_age"]},
+ ],
+ "template": "Hello {{ name }}, you are {{ age }} years old!",
+ }
+
+ def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test that TemplateTransformNode initializes correctly."""
+ node = TemplateTransformNode(
+ id="test_node",
+ config=basic_node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ assert node.node_type == NodeType.TEMPLATE_TRANSFORM
+ assert node._node_data.title == "Template Transform"
+ assert len(node._node_data.variables) == 2
+ assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!"
+
+ def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _get_title method."""
+ node = TemplateTransformNode(
+ id="test_node",
+ config=basic_node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ assert node._get_title() == "Template Transform"
+
+ def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _get_description method."""
+ node = TemplateTransformNode(
+ id="test_node",
+ config=basic_node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ assert node._get_description() == "Transform data using template"
+
+ def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _get_error_strategy method."""
+ node_data = {
+ "title": "Test",
+ "variables": [],
+ "template": "test",
+ "error_strategy": "fail-branch",
+ }
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH
+
+ def test_get_default_config(self):
+ """Test get_default_config class method."""
+ config = TemplateTransformNode.get_default_config()
+
+ assert config["type"] == "template-transform"
+ assert "config" in config
+ assert "variables" in config["config"]
+ assert "template" in config["config"]
+ assert config["config"]["template"] == "{{ arg1 }}"
+
+ def test_version(self):
+ """Test version class method."""
+ assert TemplateTransformNode.version() == "1"
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_simple_template(
+ self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
+ ):
+ """Test _run with simple template transformation."""
+ # Setup mock variable pool
+ mock_name_value = MagicMock()
+ mock_name_value.to_object.return_value = "Alice"
+ mock_age_value = MagicMock()
+ mock_age_value.to_object.return_value = 30
+
+ variable_map = {
+ ("sys", "user_name"): mock_name_value,
+ ("sys", "user_age"): mock_age_value,
+ }
+ mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
+
+ # Setup mock executor
+ mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=basic_node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["output"] == "Hello Alice, you are 30 years old!"
+ assert result.inputs["name"] == "Alice"
+ assert result.inputs["age"] == 30
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _run with None variable values."""
+ node_data = {
+ "title": "Test",
+ "variables": [{"variable": "value", "value_selector": ["sys", "missing"]}],
+ "template": "Value: {{ value }}",
+ }
+
+ mock_graph_runtime_state.variable_pool.get.return_value = None
+ mock_execute.return_value = {"result": "Value: "}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.inputs["value"] is None
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_code_execution_error(
+ self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
+ ):
+ """Test _run when code execution fails."""
+ mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
+ mock_execute.side_effect = CodeExecutionError("Template syntax error")
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=basic_node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert "Template syntax error" in result.error
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ @patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
+ def test_run_output_length_exceeds_limit(
+ self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
+ ):
+ """Test _run when output exceeds maximum length."""
+ mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
+ mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=basic_node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert "Output length exceeds" in result.error
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_complex_jinja2_template(
+ self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
+ ):
+ """Test _run with complex Jinja2 template including loops and conditions."""
+ node_data = {
+ "title": "Complex Template",
+ "variables": [
+ {"variable": "items", "value_selector": ["sys", "items"]},
+ {"variable": "show_total", "value_selector": ["sys", "show_total"]},
+ ],
+ "template": (
+ "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"
+ "{% if show_total %} (Total: {{ items|length }}){% endif %}"
+ ),
+ }
+
+ mock_items = MagicMock()
+ mock_items.to_object.return_value = ["apple", "banana", "orange"]
+ mock_show_total = MagicMock()
+ mock_show_total.to_object.return_value = True
+
+ variable_map = {
+ ("sys", "items"): mock_items,
+ ("sys", "show_total"): mock_show_total,
+ }
+ mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
+ mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["output"] == "apple, banana, orange (Total: 3)"
+
+ def test_extract_variable_selector_to_variable_mapping(self):
+ """Test _extract_variable_selector_to_variable_mapping class method."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "var1", "value_selector": ["sys", "input1"]},
+ {"variable": "var2", "value_selector": ["sys", "input2"]},
+ ],
+ "template": "{{ var1 }} {{ var2 }}",
+ }
+
+ mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping(
+ graph_config={}, node_id="node_123", node_data=node_data
+ )
+
+ assert "node_123.var1" in mapping
+ assert "node_123.var2" in mapping
+ assert mapping["node_123.var1"] == ["sys", "input1"]
+ assert mapping["node_123.var2"] == ["sys", "input2"]
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _run with no variables (static template)."""
+ node_data = {
+ "title": "Static Template",
+ "variables": [],
+ "template": "This is a static message.",
+ }
+
+ mock_execute.return_value = {"result": "This is a static message."}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["output"] == "This is a static message."
+ assert result.inputs == {}
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _run with numeric variable values."""
+ node_data = {
+ "title": "Numeric Template",
+ "variables": [
+ {"variable": "price", "value_selector": ["sys", "price"]},
+ {"variable": "quantity", "value_selector": ["sys", "quantity"]},
+ ],
+ "template": "Total: ${{ price * quantity }}",
+ }
+
+ mock_price = MagicMock()
+ mock_price.to_object.return_value = 10.5
+ mock_quantity = MagicMock()
+ mock_quantity.to_object.return_value = 3
+
+ variable_map = {
+ ("sys", "price"): mock_price,
+ ("sys", "quantity"): mock_quantity,
+ }
+ mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
+ mock_execute.return_value = {"result": "Total: $31.5"}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs["output"] == "Total: $31.5"
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _run with dictionary variable values."""
+ node_data = {
+ "title": "Dict Template",
+ "variables": [{"variable": "user", "value_selector": ["sys", "user_data"]}],
+ "template": "Name: {{ user.name }}, Email: {{ user.email }}",
+ }
+
+ mock_user = MagicMock()
+ mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
+
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_user
+ mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert "John Doe" in result.outputs["output"]
+ assert "john@example.com" in result.outputs["output"]
+
+ @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
+ def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
+ """Test _run with list variable values."""
+ node_data = {
+ "title": "List Template",
+ "variables": [{"variable": "tags", "value_selector": ["sys", "tags"]}],
+ "template": "Tags: {% for tag in tags %}#{{ tag }} {% endfor %}",
+ }
+
+ mock_tags = MagicMock()
+ mock_tags.to_object.return_value = ["python", "ai", "workflow"]
+
+ mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
+ mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
+
+ node = TemplateTransformNode(
+ id="test_node",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph=mock_graph,
+ graph_runtime_state=mock_graph_runtime_state,
+ )
+
+ result = node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert "#python" in result.outputs["output"]
+ assert "#ai" in result.outputs["output"]
+ assert "#workflow" in result.outputs["output"]
From a4c57017d5d371507a9b78c41827f3f563ac41d8 Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 26 Nov 2025 23:30:41 +0800
Subject: [PATCH 011/431] add: badges (#28722)
---
README.md | 6 ++++++
docs/ar-SA/README.md | 6 ++++++
docs/bn-BD/README.md | 6 ++++++
docs/de-DE/README.md | 6 ++++++
docs/es-ES/README.md | 6 ++++++
docs/fr-FR/README.md | 6 ++++++
docs/hi-IN/README.md | 6 ++++++
docs/it-IT/README.md | 6 ++++++
docs/ja-JP/README.md | 6 ++++++
docs/ko-KR/README.md | 6 ++++++
docs/pt-BR/README.md | 6 ++++++
docs/sl-SI/README.md | 6 ++++++
docs/tlh/README.md | 6 ++++++
docs/tr-TR/README.md | 6 ++++++
docs/vi-VN/README.md | 6 ++++++
docs/zh-CN/README.md | 6 ++++++
docs/zh-TW/README.md | 6 ++++++
17 files changed, 102 insertions(+)
diff --git a/README.md b/README.md
index e5cc05fbc0..09ba1f634b 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/ar-SA/README.md b/docs/ar-SA/README.md
index 30920ed983..99e3e3567e 100644
--- a/docs/ar-SA/README.md
+++ b/docs/ar-SA/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/bn-BD/README.md b/docs/bn-BD/README.md
index 5430364ef9..f3fa68b466 100644
--- a/docs/bn-BD/README.md
+++ b/docs/bn-BD/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/de-DE/README.md b/docs/de-DE/README.md
index 6c49fbdfc3..c71a0bfccf 100644
--- a/docs/de-DE/README.md
+++ b/docs/de-DE/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/es-ES/README.md b/docs/es-ES/README.md
index ae83d416e3..da81b51d6a 100644
--- a/docs/es-ES/README.md
+++ b/docs/es-ES/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md
index b7d006a927..03f3221798 100644
--- a/docs/fr-FR/README.md
+++ b/docs/fr-FR/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/hi-IN/README.md b/docs/hi-IN/README.md
index 7c4fc70db0..bedeaa6246 100644
--- a/docs/hi-IN/README.md
+++ b/docs/hi-IN/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/it-IT/README.md b/docs/it-IT/README.md
index 598e87ec25..2e96335d3e 100644
--- a/docs/it-IT/README.md
+++ b/docs/it-IT/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/ja-JP/README.md b/docs/ja-JP/README.md
index f9e700d1df..659ffbda51 100644
--- a/docs/ja-JP/README.md
+++ b/docs/ja-JP/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/ko-KR/README.md b/docs/ko-KR/README.md
index 4e4b82e920..2f6c526ef2 100644
--- a/docs/ko-KR/README.md
+++ b/docs/ko-KR/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/pt-BR/README.md b/docs/pt-BR/README.md
index 444faa0a67..ed29ec0294 100644
--- a/docs/pt-BR/README.md
+++ b/docs/pt-BR/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
diff --git a/docs/sl-SI/README.md b/docs/sl-SI/README.md
index 04dc3b5dff..caef2c303c 100644
--- a/docs/sl-SI/README.md
+++ b/docs/sl-SI/README.md
@@ -33,6 +33,12 @@
+
+
+
+
+
+
diff --git a/docs/tlh/README.md b/docs/tlh/README.md
index b1e3016efd..a25849c443 100644
--- a/docs/tlh/README.md
+++ b/docs/tlh/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/tr-TR/README.md b/docs/tr-TR/README.md
index 965a1704be..6361ca5dd9 100644
--- a/docs/tr-TR/README.md
+++ b/docs/tr-TR/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/vi-VN/README.md b/docs/vi-VN/README.md
index 07329e84cd..3042a98d95 100644
--- a/docs/vi-VN/README.md
+++ b/docs/vi-VN/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/zh-CN/README.md b/docs/zh-CN/README.md
index 888a0d7f12..15bb447ad8 100644
--- a/docs/zh-CN/README.md
+++ b/docs/zh-CN/README.md
@@ -32,6 +32,12 @@
+
+
+
+
+
+
diff --git a/docs/zh-TW/README.md b/docs/zh-TW/README.md
index d8c484a6d4..14b343ba29 100644
--- a/docs/zh-TW/README.md
+++ b/docs/zh-TW/README.md
@@ -36,6 +36,12 @@

+
+ 
+
+ 
+
+
From 4ccc150fd190a9151f0e9d674f18ff5773fb068c Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 07:33:46 -0800
Subject: [PATCH 012/431] test: add comprehensive unit tests for
ExternalDatasetService (external knowledge API integration) (#28716)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/external_dataset_service.py | 920 ++++++++++++++++++
1 file changed, 920 insertions(+)
create mode 100644 api/tests/unit_tests/services/external_dataset_service.py
diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py
new file mode 100644
index 0000000000..1647eb3e85
--- /dev/null
+++ b/api/tests/unit_tests/services/external_dataset_service.py
@@ -0,0 +1,920 @@
+"""
+Extensive unit tests for ``ExternalDatasetService``.
+
+This module focuses on the *external dataset service* surface area, which is responsible
+for integrating with **external knowledge APIs** and wiring them into Dify datasets.
+
+The goal of this test suite is twofold:
+
+- Provide **high‑confidence regression coverage** for all public helpers on
+ ``ExternalDatasetService``.
+- Serve as **executable documentation** for how external API integration is expected
+ to behave in different scenarios (happy paths, validation failures, and error codes).
+
+The file intentionally contains **rich comments and generous spacing** in order to make
+each scenario easy to scan during reviews.
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import MagicMock, Mock, patch
+
+import httpx
+import pytest
+
+from constants import HIDDEN_VALUE
+from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
+from services.entities.external_knowledge_entities.external_knowledge_entities import (
+ Authorization,
+ AuthorizationConfig,
+ ExternalKnowledgeApiSetting,
+)
+from services.errors.dataset import DatasetNameDuplicateError
+from services.external_knowledge_service import ExternalDatasetService
+
+
+class ExternalDatasetTestDataFactory:
+ """
+ Factory helpers for building *lightweight* mocks for external knowledge tests.
+
+ These helpers are intentionally small and explicit:
+
+ - They avoid pulling in unnecessary fixtures.
+ - They reflect the minimal contract that the service under test cares about.
+ """
+
+ @staticmethod
+ def create_external_api(
+ api_id: str = "api-123",
+ tenant_id: str = "tenant-1",
+ name: str = "Test API",
+ description: str = "Description",
+ settings: dict | None = None,
+ ) -> ExternalKnowledgeApis:
+ """
+ Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.
+
+ Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
+ exercise ``settings_dict`` and other convenience properties if needed.
+ """
+
+ instance = ExternalKnowledgeApis(
+ tenant_id=tenant_id,
+ name=name,
+ description=description,
+ settings=None if settings is None else cast(str, pytest.approx), # type: ignore[assignment]
+ )
+
+ # Overwrite generated id for determinism in assertions.
+ instance.id = api_id
+ return instance
+
+ @staticmethod
+ def create_dataset(
+ dataset_id: str = "ds-1",
+ tenant_id: str = "tenant-1",
+ name: str = "External Dataset",
+ provider: str = "external",
+ ) -> Dataset:
+ """
+ Build a small ``Dataset`` instance representing an external dataset.
+ """
+
+ dataset = Dataset(
+ tenant_id=tenant_id,
+ name=name,
+ description="",
+ provider=provider,
+ created_by="user-1",
+ )
+ dataset.id = dataset_id
+ return dataset
+
+ @staticmethod
+ def create_external_binding(
+ tenant_id: str = "tenant-1",
+ dataset_id: str = "ds-1",
+ api_id: str = "api-1",
+ external_knowledge_id: str = "knowledge-1",
+ ) -> ExternalKnowledgeBindings:
+ """
+ Small helper for a binding between dataset and external knowledge API.
+ """
+
+ binding = ExternalKnowledgeBindings(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ external_knowledge_api_id=api_id,
+ external_knowledge_id=external_knowledge_id,
+ created_by="user-1",
+ )
+ return binding
+
+
+# ---------------------------------------------------------------------------
+# get_external_knowledge_apis
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceGetExternalKnowledgeApis:
+ """
+ Tests for ``ExternalDatasetService.get_external_knowledge_apis``.
+
+ These tests focus on:
+
+ - Basic pagination wiring via ``db.paginate``.
+ - Optional search keyword behaviour.
+ """
+
+ @pytest.fixture
+ def mock_db_paginate(self):
+ """
+ Patch ``db.paginate`` so we do not touch the real database layer.
+ """
+
+ with (
+ patch("services.external_knowledge_service.db.paginate") as mock_paginate,
+ patch("services.external_knowledge_service.select"),
+ ):
+ yield mock_paginate
+
+ def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
+ """
+ It should return ``items`` and ``total`` coming from the paginate object.
+ """
+
+ # Arrange
+ tenant_id = "tenant-1"
+ page = 1
+ per_page = 20
+
+ mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
+ mock_pagination = SimpleNamespace(items=mock_items, total=42)
+ mock_db_paginate.return_value = mock_pagination
+
+ # Act
+ items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)
+
+ # Assert
+ assert items is mock_items
+ assert total == 42
+
+ mock_db_paginate.assert_called_once()
+ call_kwargs = mock_db_paginate.call_args.kwargs
+ assert call_kwargs["page"] == page
+ assert call_kwargs["per_page"] == per_page
+ assert call_kwargs["max_per_page"] == 100
+ assert call_kwargs["error_out"] is False
+
+ def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
+ """
+ When a search keyword is provided, the query should be adjusted
+ (we simply assert that paginate is still called and does not explode).
+ """
+
+ # Arrange
+ tenant_id = "tenant-1"
+ page = 2
+ per_page = 10
+ search = "foo"
+
+ mock_pagination = SimpleNamespace(items=[], total=0)
+ mock_db_paginate.return_value = mock_pagination
+
+ # Act
+ items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search)
+
+ # Assert
+ assert items == []
+ assert total == 0
+ mock_db_paginate.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# validate_api_list
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceValidateApiList:
+ """
+ Lightweight validation tests for ``validate_api_list``.
+ """
+
+ def test_validate_api_list_success(self):
+ """
+ A minimal valid configuration (endpoint + api_key) should pass.
+ """
+
+ config = {"endpoint": "https://example.com", "api_key": "secret"}
+
+ # Act & Assert – no exception expected
+ ExternalDatasetService.validate_api_list(config)
+
+ @pytest.mark.parametrize(
+ ("config", "expected_message"),
+ [
+ ({}, "api list is empty"),
+ ({"api_key": "k"}, "endpoint is required"),
+ ({"endpoint": "https://example.com"}, "api_key is required"),
+ ],
+ )
+ def test_validate_api_list_failures(self, config: dict, expected_message: str):
+ """
+ Invalid configs should raise ``ValueError`` with a clear message.
+ """
+
+ with pytest.raises(ValueError, match=expected_message):
+ ExternalDatasetService.validate_api_list(config)
+
+
+# ---------------------------------------------------------------------------
+# create_external_knowledge_api & get/update/delete
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceCrudExternalKnowledgeApi:
+ """
+ CRUD tests for external knowledge API templates.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Patch ``db.session`` for all CRUD tests in this class.
+ """
+
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
+ """
+ ``create_external_knowledge_api`` should persist a new record
+ when settings are present and valid.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+ args = {
+ "name": "API",
+ "description": "desc",
+ "settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
+ }
+
+ # We do not want to actually call the remote endpoint here, so we patch the validator.
+ with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
+ result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
+
+ assert isinstance(result, ExternalKnowledgeApis)
+ mock_check.assert_called_once_with(args["settings"])
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
+ """
+ Missing ``settings`` should result in a ``ValueError``.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+ args = {"name": "API", "description": "desc"}
+
+ with pytest.raises(ValueError, match="settings is required"):
+ ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
+
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
+ """
+ ``get_external_knowledge_api`` should return the first matching record.
+ """
+
+ api = Mock(spec=ExternalKnowledgeApis)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
+
+ result = ExternalDatasetService.get_external_knowledge_api("api-id")
+ assert result is api
+
+ def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ When the record is absent, a ``ValueError`` is raised.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.get_external_knowledge_api("missing-id")
+
+ def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
+ """
+ Updating an API should keep the existing API key when the special hidden
+ value placeholder is sent from the UI.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+ api_id = "api-1"
+
+ existing_api = Mock(spec=ExternalKnowledgeApis)
+ existing_api.settings_dict = {"api_key": "stored-key"}
+ existing_api.settings = '{"api_key":"stored-key"}'
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
+
+ args = {
+ "name": "New Name",
+ "description": "New Desc",
+ "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
+ }
+
+ result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
+
+ assert result is existing_api
+ # The placeholder should be replaced with stored key.
+ assert args["settings"]["api_key"] == "stored-key"
+ mock_db_session.commit.assert_called_once()
+
+ def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Updating a non‑existent API template should raise ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.update_external_knowledge_api(
+ tenant_id="tenant-1",
+ user_id="user-1",
+ external_knowledge_api_id="missing-id",
+ args={"name": "n", "description": "d", "settings": {}},
+ )
+
+ def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
+ """
+ ``delete_external_knowledge_api`` should delete and commit when found.
+ """
+
+ api = Mock(spec=ExternalKnowledgeApis)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
+
+ ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
+
+ mock_db_session.delete.assert_called_once_with(api)
+ mock_db_session.commit.assert_called_once()
+
+ def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Deletion of a missing template should raise ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
+
+
+# ---------------------------------------------------------------------------
+# external_knowledge_api_use_check & binding lookups
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceUsageAndBindings:
+ """
+ Tests for usage checks and dataset binding retrieval.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
+ """
+ When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
+
+ in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
+
+ assert in_use is True
+ assert count == 3
+
+ def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
+ """
+ Zero bindings should return ``(False, 0)``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
+
+ in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
+
+ assert in_use is False
+ assert count == 0
+
+ def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
+ """
+ Binding lookup should return the first record when present.
+ """
+
+ binding = Mock(spec=ExternalKnowledgeBindings)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
+
+ result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
+ assert result is binding
+
+ def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Missing binding should result in a ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="external knowledge binding not found"):
+ ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
+
+
+# ---------------------------------------------------------------------------
+# document_create_args_validate
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceDocumentCreateArgsValidate:
+ """
+ Tests for ``document_create_args_validate``.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
+ """
+ All required custom parameters present – validation should pass.
+ """
+
+ external_api = Mock(spec=ExternalKnowledgeApis)
+ external_api.settings = json_settings = (
+ '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
+ )
+ # Raw string; the service itself calls json.loads on it
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
+
+ process_parameter = {"foo": "value", "bar": "optional"}
+
+ # Act & Assert – no exception
+ ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
+
+ assert json_settings in external_api.settings # simple sanity check on our test data
+
+ def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
+ """
+ When the referenced API template is missing, a ``ValueError`` is raised.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
+
+ def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
+ """
+ Required document process parameters must be supplied.
+ """
+
+ external_api = Mock(spec=ExternalKnowledgeApis)
+ external_api.settings = (
+ '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
+ )
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
+
+ process_parameter = {"bar": "present"} # missing "foo"
+
+ with pytest.raises(ValueError, match="foo is required"):
+ ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
+
+
+# ---------------------------------------------------------------------------
+# process_external_api
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceProcessExternalApi:
+ """
+ Tests focused on the HTTP request assembly and method mapping behaviour.
+ """
+
+ def test_process_external_api_valid_method_post(self):
+ """
+ For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
+ """
+
+ settings = ExternalKnowledgeApiSetting(
+ url="https://example.com/path",
+ request_method="POST",
+ headers={"X-Test": "1"},
+ params={"foo": "bar"},
+ )
+
+ fake_response = httpx.Response(200)
+
+ with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
+ mock_post.return_value = fake_response
+
+ result = ExternalDatasetService.process_external_api(settings, files=None)
+
+ assert result is fake_response
+ mock_post.assert_called_once()
+ kwargs = mock_post.call_args.kwargs
+ assert kwargs["url"] == settings.url
+ assert kwargs["headers"] == settings.headers
+ assert kwargs["follow_redirects"] is True
+ assert "data" in kwargs
+
+ def test_process_external_api_invalid_method_raises(self):
+ """
+ An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
+ """
+
+ settings = ExternalKnowledgeApiSetting(
+ url="https://example.com",
+ request_method="INVALID",
+ headers=None,
+ params={},
+ )
+
+ from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
+
+ with pytest.raises(InvalidHttpMethodError):
+ ExternalDatasetService.process_external_api(settings, files=None)
+
+
+# ---------------------------------------------------------------------------
+# assembling_headers
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceAssemblingHeaders:
+ """
+ Tests for header assembly based on different authentication flavours.
+ """
+
+ def test_assembling_headers_bearer_token(self):
+ """
+ For bearer auth we expect ``Authorization: Bearer `` by default.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
+ )
+
+ headers = ExternalDatasetService.assembling_headers(auth)
+
+ assert headers["Authorization"] == "Bearer secret"
+
+ def test_assembling_headers_basic_token_with_custom_header(self):
+ """
+ For basic auth we honour the configured header name.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
+ )
+
+ headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"})
+
+ assert headers["Existing"] == "1"
+ assert headers["X-Auth"] == "Basic abc123"
+
+ def test_assembling_headers_custom_type(self):
+ """
+ Custom auth type should inject the raw API key.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
+ )
+
+ headers = ExternalDatasetService.assembling_headers(auth, headers=None)
+
+ assert headers["X-API-KEY"] == "raw-key"
+
+ def test_assembling_headers_missing_config_raises(self):
+ """
+ Missing config object should be rejected.
+ """
+
+ auth = Authorization(type="api-key", config=None)
+
+ with pytest.raises(ValueError, match="authorization config is required"):
+ ExternalDatasetService.assembling_headers(auth)
+
+ def test_assembling_headers_missing_api_key_raises(self):
+ """
+ ``api_key`` is required when type is ``api-key``.
+ """
+
+ auth = Authorization(
+ type="api-key",
+ config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
+ )
+
+ with pytest.raises(ValueError, match="api_key is required"):
+ ExternalDatasetService.assembling_headers(auth)
+
+ def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
+ """
+ For ``no-auth`` we should not modify the headers mapping.
+ """
+
+ auth = Authorization(type="no-auth", config=None)
+
+ base_headers = {"X": "1"}
+ result = ExternalDatasetService.assembling_headers(auth, headers=base_headers)
+
+ # A copy is returned, original is not mutated.
+ assert result == base_headers
+ assert result is not base_headers
+
+
+# ---------------------------------------------------------------------------
+# get_external_knowledge_api_settings
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
+ """
+ Simple shape test for ``get_external_knowledge_api_settings``.
+ """
+
+ def test_get_external_knowledge_api_settings(self):
+ settings_dict: dict[str, Any] = {
+ "url": "https://example.com/retrieval",
+ "request_method": "post",
+ "headers": {"Content-Type": "application/json"},
+ "params": {"foo": "bar"},
+ }
+
+ result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict)
+
+ assert isinstance(result, ExternalKnowledgeApiSetting)
+ assert result.url == settings_dict["url"]
+ assert result.request_method == settings_dict["request_method"]
+ assert result.headers == settings_dict["headers"]
+ assert result.params == settings_dict["params"]
+
+
+# ---------------------------------------------------------------------------
+# create_external_dataset
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceCreateExternalDataset:
+ """
+ Tests around creating the external dataset and its binding row.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_create_external_dataset_success(self, mock_db_session: MagicMock):
+ """
+ A brand new dataset name with valid external knowledge references
+ should create both the dataset and its binding.
+ """
+
+ tenant_id = "tenant-1"
+ user_id = "user-1"
+
+ args = {
+ "name": "My Dataset",
+ "description": "desc",
+ "external_knowledge_api_id": "api-1",
+ "external_knowledge_id": "knowledge-1",
+ "external_retrieval_model": {"top_k": 3},
+ }
+
+ # No existing dataset with same name.
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ None, # duplicate‑name check
+ Mock(spec=ExternalKnowledgeApis), # external knowledge api
+ ]
+
+ dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
+
+ assert isinstance(dataset, Dataset)
+ assert dataset.provider == "external"
+ assert dataset.retrieval_model == args["external_retrieval_model"]
+
+ assert mock_db_session.add.call_count >= 2 # dataset + binding
+ mock_db_session.flush.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
+ """
+ When a dataset with the same name already exists,
+ ``DatasetNameDuplicateError`` is raised.
+ """
+
+ existing_dataset = Mock(spec=Dataset)
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
+
+ args = {
+ "name": "Existing",
+ "external_knowledge_api_id": "api-1",
+ "external_knowledge_id": "knowledge-1",
+ }
+
+ with pytest.raises(DatasetNameDuplicateError):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
+
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
+ """
+ If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
+ """
+
+ # First call: duplicate name check – not found.
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ None,
+ None, # external knowledge api lookup
+ ]
+
+ args = {
+ "name": "Dataset",
+ "external_knowledge_api_id": "missing",
+ "external_knowledge_id": "knowledge-1",
+ }
+
+ with pytest.raises(ValueError, match="api template not found"):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
+
+ def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
+ """
+ ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
+ """
+
+ # duplicate name check
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ None,
+ Mock(spec=ExternalKnowledgeApis),
+ ]
+
+ args_missing_knowledge_id = {
+ "name": "Dataset",
+ "external_knowledge_api_id": "api-1",
+ "external_knowledge_id": None,
+ }
+
+ with pytest.raises(ValueError, match="external_knowledge_id is required"):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)
+
+ args_missing_api_id = {
+ "name": "Dataset",
+ "external_knowledge_api_id": None,
+ "external_knowledge_id": "k-1",
+ }
+
+ with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
+ ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
+
+
+# ---------------------------------------------------------------------------
+# fetch_external_knowledge_retrieval
+# ---------------------------------------------------------------------------
+
+
+class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
+ """
+ Tests for ``fetch_external_knowledge_retrieval`` which orchestrates
+ external retrieval requests and normalises the response payload.
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ with patch("services.external_knowledge_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
+ """
+ With a valid binding and API template, records from the external
+ service should be returned when the HTTP response is 200.
+ """
+
+ tenant_id = "tenant-1"
+ dataset_id = "ds-1"
+ query = "test query"
+ external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}
+
+ binding = ExternalDatasetTestDataFactory.create_external_binding(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ api_id="api-1",
+ external_knowledge_id="knowledge-1",
+ )
+
+ api = Mock(spec=ExternalKnowledgeApis)
+ api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
+
+ # First query: binding; second query: api.
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ binding,
+ api,
+ ]
+
+ fake_records = [{"content": "doc", "score": 0.9}]
+ fake_response = Mock(spec=httpx.Response)
+ fake_response.status_code = 200
+ fake_response.json.return_value = {"records": fake_records}
+
+ metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
+
+ with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
+ result = ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ query=query,
+ external_retrieval_parameters=external_retrieval_parameters,
+ metadata_condition=metadata_condition,
+ )
+
+ assert result == fake_records
+
+ mock_process.assert_called_once()
+ setting_arg = mock_process.call_args.args[0]
+ assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
+ assert setting_arg.url.endswith("/retrieval")
+
+ def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
+ """
+ Missing binding should raise ``ValueError``.
+ """
+
+ mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="external knowledge binding not found"):
+ ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id="tenant-1",
+ dataset_id="missing",
+ query="q",
+ external_retrieval_parameters={},
+ metadata_condition=None,
+ )
+
+ def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
+ """
+ When the API template is missing or has no settings, a ``ValueError`` is raised.
+ """
+
+ binding = ExternalDatasetTestDataFactory.create_external_binding()
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ binding,
+ None,
+ ]
+
+ with pytest.raises(ValueError, match="external api template not found"):
+ ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id="tenant-1",
+ dataset_id="ds-1",
+ query="q",
+ external_retrieval_parameters={},
+ metadata_condition=None,
+ )
+
+ def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
+ """
+ Non‑200 responses should be treated as an empty result set.
+ """
+
+ binding = ExternalDatasetTestDataFactory.create_external_binding()
+ api = Mock(spec=ExternalKnowledgeApis)
+ api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
+
+ mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
+ binding,
+ api,
+ ]
+
+ fake_response = Mock(spec=httpx.Response)
+ fake_response.status_code = 500
+ fake_response.json.return_value = {}
+
+ with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
+ result = ExternalDatasetService.fetch_external_knowledge_retrieval(
+ tenant_id="tenant-1",
+ dataset_id="ds-1",
+ query="q",
+ external_retrieval_parameters={},
+ metadata_condition=None,
+ )
+
+ assert result == []
From 38522e5dfa38831d44655faef068a525852f7ea2 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 27 Nov 2025 08:39:49 +0800
Subject: [PATCH 013/431] fix: use default_factory for callable defaults in ORM
dataclasses (#28730)
---
api/models/account.py | 24 +++++--
api/models/api_based_extension.py | 4 +-
api/models/dataset.py | 106 +++++++++++++++++++++++++-----
api/models/model.py | 52 +++++++++++----
api/models/oauth.py | 12 +++-
api/models/provider.py | 40 ++++++++---
api/models/source.py | 8 ++-
api/models/task.py | 7 +-
api/models/tools.py | 44 +++++++++----
api/models/trigger.py | 36 +++++++---
api/models/web.py | 8 ++-
api/models/workflow.py | 4 +-
12 files changed, 269 insertions(+), 76 deletions(-)
diff --git a/api/models/account.py b/api/models/account.py
index b1dafed0ed..420e6adc6c 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -88,7 +88,9 @@ class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255), default=None)
@@ -235,7 +237,9 @@ class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")
@@ -275,7 +279,9 @@ class TenantAccountJoin(TypeBase):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
@@ -297,7 +303,9 @@ class AccountIntegrate(TypeBase):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
@@ -348,7 +356,9 @@ class TenantPluginPermission(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
@@ -375,7 +385,9 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py
index 99d33908f8..b5acab5a75 100644
--- a/api/models/api_based_extension.py
+++ b/api/models/api_based_extension.py
@@ -24,7 +24,9 @@ class APIBasedExtension(TypeBase):
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 2ea6d98b5f..e072711b82 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -920,7 +920,12 @@ class AppDatasetJoin(TypeBase):
)
id: Mapped[str] = mapped_column(
- StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
+ StringUUID,
+ primary_key=True,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -941,7 +946,12 @@ class DatasetQuery(TypeBase):
)
id: Mapped[str] = mapped_column(
- StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
+ StringUUID,
+ primary_key=True,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
@@ -961,7 +971,13 @@ class DatasetKeywordTable(TypeBase):
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False, unique=True)
keyword_table: Mapped[str] = mapped_column(LongText, nullable=False)
data_source_type: Mapped[str] = mapped_column(
@@ -1012,7 +1028,13 @@ class Embedding(TypeBase):
sa.Index("created_at_idx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
model_name: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'")
)
@@ -1037,7 +1059,13 @@ class DatasetCollectionBinding(TypeBase):
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
@@ -1073,7 +1101,13 @@ class Whitelist(TypeBase):
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
sa.Index("whitelists_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
category: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
@@ -1090,7 +1124,13 @@ class DatasetPermission(TypeBase):
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), primary_key=True, init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ primary_key=True,
+ init=False,
+ )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1110,7 +1150,13 @@ class ExternalKnowledgeApis(TypeBase):
sa.Index("external_knowledge_apis_name_idx", "name"),
)
- id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1167,7 +1213,13 @@ class ExternalKnowledgeBindings(TypeBase):
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1191,7 +1243,9 @@ class DatasetAutoDisableLog(TypeBase):
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1209,7 +1263,9 @@ class RateLimitLog(TypeBase):
sa.Index("rate_limit_log_operation_idx", "operation"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
operation: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1226,7 +1282,9 @@ class DatasetMetadata(TypeBase):
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1255,7 +1313,9 @@ class DatasetMetadataBinding(TypeBase):
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
metadata_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1270,7 +1330,9 @@ class PipelineBuiltInTemplate(TypeBase):
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@@ -1300,7 +1362,9 @@ class PipelineCustomizedTemplate(TypeBase):
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
@@ -1335,7 +1399,9 @@ class Pipeline(TypeBase):
__tablename__ = "pipelines"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("''"))
@@ -1368,7 +1434,9 @@ class DocumentPipelineExecutionLog(TypeBase):
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
pipeline_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@@ -1385,7 +1453,9 @@ class PipelineRecommendedPlugin(TypeBase):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
diff --git a/api/models/model.py b/api/models/model.py
index 33a94628f0..1731ff5699 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -572,7 +572,9 @@ class InstalledApp(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_owner_tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -606,7 +608,9 @@ class OAuthProviderApp(TypeBase):
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1311,7 +1315,9 @@ class MessageFeedback(TypeBase):
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1360,7 +1366,9 @@ class MessageFile(TypeBase):
sa.Index("message_file_created_by_idx", "created_by"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
@@ -1452,7 +1460,9 @@ class AppAnnotationSetting(TypeBase):
sa.Index("app_annotation_settings_app_idx", "app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1488,7 +1498,9 @@ class OperationLog(TypeBase):
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1554,7 +1566,9 @@ class AppMCPServer(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1764,7 +1778,9 @@ class ApiRequest(TypeBase):
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
api_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
path: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1783,7 +1799,9 @@ class MessageChain(TypeBase):
sa.Index("message_chain_message_id_idx", "message_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
@@ -1914,7 +1932,9 @@ class DatasetRetrieverResource(TypeBase):
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1946,7 +1966,9 @@ class Tag(TypeBase):
TAG_TYPE_LIST = ["knowledge", "app"]
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1964,7 +1986,9 @@ class TagBinding(TypeBase):
sa.Index("tag_bind_tag_id_idx", "tag_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
tag_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
target_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
@@ -1981,7 +2005,9 @@ class TraceAppConfig(TypeBase):
sa.Index("trace_app_config_app_id_idx", "app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)
diff --git a/api/models/oauth.py b/api/models/oauth.py
index 2fce67c998..1db2552469 100644
--- a/api/models/oauth.py
+++ b/api/models/oauth.py
@@ -17,7 +17,9 @@ class DatasourceOauthParamConfig(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
@@ -30,7 +32,9 @@ class DatasourceProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
@@ -60,7 +64,9 @@ class DatasourceOauthTenantParamConfig(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
diff --git a/api/models/provider.py b/api/models/provider.py
index 577e098a2e..2afd8c5329 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -58,7 +58,13 @@ class Provider(TypeBase):
),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuidv7()),
+ default_factory=lambda: str(uuidv7()),
+ init=False,
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
@@ -132,7 +138,9 @@ class ProviderModel(TypeBase):
),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -173,7 +181,9 @@ class TenantDefaultModel(TypeBase):
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -193,7 +203,9 @@ class TenantPreferredModelProvider(TypeBase):
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
@@ -212,7 +224,9 @@ class ProviderOrder(TypeBase):
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -245,7 +259,9 @@ class ProviderModelSetting(TypeBase):
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -273,7 +289,9 @@ class LoadBalancingModelConfig(TypeBase):
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -302,7 +320,9 @@ class ProviderCredential(TypeBase):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -332,7 +352,9 @@ class ProviderModelCredential(TypeBase):
),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
diff --git a/api/models/source.py b/api/models/source.py
index f093048c00..a8addbe342 100644
--- a/api/models/source.py
+++ b/api/models/source.py
@@ -18,7 +18,9 @@ class DataSourceOauthBinding(TypeBase):
adjusted_json_index("source_info_idx", "source_info"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -44,7 +46,9 @@ class DataSourceApiKeyAuthBinding(TypeBase):
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
diff --git a/api/models/task.py b/api/models/task.py
index 539945b251..d98d99ca2c 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -24,7 +24,8 @@ class CeleryTask(TypeBase):
result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
date_done: Mapped[datetime | None] = mapped_column(
DateTime,
- default=naive_utc_now,
+ insert_default=naive_utc_now,
+ default=None,
onupdate=naive_utc_now,
nullable=True,
)
@@ -47,4 +48,6 @@ class CeleryTaskSet(TypeBase):
)
taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
- date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
+ date_done: Mapped[datetime | None] = mapped_column(
+ DateTime, insert_default=naive_utc_now, default=None, nullable=True
+ )
diff --git a/api/models/tools.py b/api/models/tools.py
index 0a79f95a70..e4f9bcb582 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -30,7 +30,9 @@ class ToolOAuthSystemClient(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider
@@ -45,7 +47,9 @@ class ToolOAuthTenantClient(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -71,7 +75,9 @@ class BuiltinToolProvider(TypeBase):
)
# id of the tool provider
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(
String(256),
nullable=False,
@@ -120,7 +126,9 @@ class ApiToolProvider(TypeBase):
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# name of the api provider
name: Mapped[str] = mapped_column(
String(255),
@@ -192,7 +200,9 @@ class ToolLabelBinding(TypeBase):
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
@@ -213,7 +223,9 @@ class WorkflowToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider
@@ -279,7 +291,9 @@ class MCPToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# name of the mcp provider
name: Mapped[str] = mapped_column(String(40), nullable=False)
# server identifier of the mcp provider
@@ -360,7 +374,9 @@ class ToolModelInvoke(TypeBase):
__tablename__ = "tool_model_invokes"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# who invoke this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@@ -413,7 +429,9 @@ class ToolConversationVariables(TypeBase):
sa.Index("conversation_id_idx", "conversation_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@@ -450,7 +468,9 @@ class ToolFile(TypeBase):
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id
@@ -481,7 +501,9 @@ class DeprecatedPublishedAppTool(TypeBase):
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# id of the app
app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
diff --git a/api/models/trigger.py b/api/models/trigger.py
index 088e797f82..87e2a5ccfc 100644
--- a/api/models/trigger.py
+++ b/api/models/trigger.py
@@ -41,7 +41,9 @@ class TriggerSubscription(TypeBase):
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -111,7 +113,9 @@ class TriggerOAuthSystemClient(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the trigger provider
@@ -136,7 +140,9 @@ class TriggerOAuthTenantClient(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -202,7 +208,9 @@ class WorkflowTriggerLog(TypeBase):
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -294,7 +302,9 @@ class WorkflowWebhookTrigger(TypeBase):
sa.UniqueConstraint("webhook_id", name="uniq_webhook_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -351,7 +361,9 @@ class WorkflowPluginTrigger(TypeBase):
sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -395,7 +407,9 @@ class AppTrigger(TypeBase):
sa.Index("app_trigger_tenant_app_idx", "tenant_id", "app_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)
@@ -443,7 +457,13 @@ class WorkflowSchedulePlan(TypeBase):
sa.Index("workflow_schedule_plan_next_idx", "next_run_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuidv7()),
+ default_factory=lambda: str(uuidv7()),
+ init=False,
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
diff --git a/api/models/web.py b/api/models/web.py
index 4f0bf7c7da..b2832aa163 100644
--- a/api/models/web.py
+++ b/api/models/web.py
@@ -18,7 +18,9 @@ class SavedMessage(TypeBase):
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))
@@ -42,7 +44,9 @@ class PinnedConversation(TypeBase):
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role: Mapped[str] = mapped_column(
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 4efa829692..42ee8a1f2b 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1103,7 +1103,9 @@ class WorkflowAppLog(TypeBase):
sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
From 64babb35e2c6e75808fab81739b83d2aa6fe8821 Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 17:55:42 -0800
Subject: [PATCH 014/431] feat: Add comprehensive unit tests for
DatasetCollectionBindingService (dataset collection binding methods) (#28724)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/dataset_collection_binding.py | 932 ++++++++++++++++++
1 file changed, 932 insertions(+)
create mode 100644 api/tests/unit_tests/services/dataset_collection_binding.py
diff --git a/api/tests/unit_tests/services/dataset_collection_binding.py b/api/tests/unit_tests/services/dataset_collection_binding.py
new file mode 100644
index 0000000000..2a939a5c1d
--- /dev/null
+++ b/api/tests/unit_tests/services/dataset_collection_binding.py
@@ -0,0 +1,932 @@
+"""
+Comprehensive unit tests for DatasetCollectionBindingService.
+
+This module contains extensive unit tests for the DatasetCollectionBindingService class,
+which handles dataset collection binding operations for vector database collections.
+
+The DatasetCollectionBindingService provides methods for:
+- Retrieving or creating dataset collection bindings by provider, model, and type
+- Retrieving specific collection bindings by ID and type
+- Managing collection bindings for different collection types (dataset, etc.)
+
+Collection bindings are used to map embedding models (provider + model name) to
+specific vector database collections, allowing datasets to share collections when
+they use the same embedding model configuration.
+
+This test suite ensures:
+- Correct retrieval of existing bindings
+- Proper creation of new bindings when they don't exist
+- Accurate filtering by provider, model, and collection type
+- Proper error handling for missing bindings
+- Database transaction handling (add, commit)
+- Collection name generation using Dataset.gen_collection_name_by_id
+
+================================================================================
+ARCHITECTURE OVERVIEW
+================================================================================
+
+The DatasetCollectionBindingService is a critical component in the Dify platform's
+vector database management system. It serves as an abstraction layer between the
+application logic and the underlying vector database collections.
+
+Key Concepts:
+1. Collection Binding: A mapping between an embedding model configuration
+ (provider + model name) and a vector database collection name. This allows
+ multiple datasets to share the same collection when they use identical
+ embedding models, improving resource efficiency.
+
+2. Collection Type: Different types of collections can exist (e.g., "dataset",
+ "custom_type"). This allows for separation of collections based on their
+ intended use case or data structure.
+
+3. Provider and Model: The combination of provider_name (e.g., "openai",
+ "cohere", "huggingface") and model_name (e.g., "text-embedding-ada-002")
+ uniquely identifies an embedding model configuration.
+
+4. Collection Name Generation: When a new binding is created, a unique collection
+ name is generated using Dataset.gen_collection_name_by_id() with a UUID.
+ This ensures each binding has a unique collection identifier.
+
+================================================================================
+TESTING STRATEGY
+================================================================================
+
+This test suite follows a comprehensive testing strategy that covers:
+
+1. Happy Path Scenarios:
+ - Successful retrieval of existing bindings
+ - Successful creation of new bindings
+ - Proper handling of default parameters
+
+2. Edge Cases:
+ - Different collection types
+ - Various provider/model combinations
+ - Default vs explicit parameter usage
+
+3. Error Handling:
+ - Missing bindings (for get_by_id_and_type)
+ - Database query failures
+ - Invalid parameter combinations
+
+4. Database Interaction:
+ - Query construction and execution
+ - Transaction management (add, commit)
+ - Query chaining (where, order_by, first)
+
+5. Mocking Strategy:
+ - Database session mocking
+ - Query builder chain mocking
+ - UUID generation mocking
+ - Collection name generation mocking
+
+================================================================================
+"""
+
+"""
+Import statements for the test module.
+
+This section imports all necessary dependencies for testing the
+DatasetCollectionBindingService, including:
+- unittest.mock for creating mock objects
+- pytest for test framework functionality
+- uuid for UUID generation (used in collection name generation)
+- Models and services from the application codebase
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from models.dataset import Dataset, DatasetCollectionBinding
+from services.dataset_service import DatasetCollectionBindingService
+
+# ============================================================================
+# Test Data Factory
+# ============================================================================
+# The Test Data Factory pattern is used here to centralize the creation of
+# test objects and mock instances. This approach provides several benefits:
+#
+# 1. Consistency: All test objects are created using the same factory methods,
+# ensuring consistent structure across all tests.
+#
+# 2. Maintainability: If the structure of DatasetCollectionBinding or Dataset
+# changes, we only need to update the factory methods rather than every
+# individual test.
+#
+# 3. Reusability: Factory methods can be reused across multiple test classes,
+# reducing code duplication.
+#
+# 4. Readability: Tests become more readable when they use descriptive factory
+# method calls instead of complex object construction logic.
+#
+# ============================================================================
+
+
+class DatasetCollectionBindingTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for dataset collection binding tests.
+
+ This factory provides static methods to create mock objects for:
+ - DatasetCollectionBinding instances
+ - Database query results
+ - Collection name generation results
+
+ The factory methods help maintain consistency across tests and reduce
+ code duplication when setting up test scenarios.
+ """
+
+ @staticmethod
+ def create_collection_binding_mock(
+ binding_id: str = "binding-123",
+ provider_name: str = "openai",
+ model_name: str = "text-embedding-ada-002",
+ collection_name: str = "collection-abc",
+ collection_type: str = "dataset",
+ created_at=None,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock DatasetCollectionBinding with specified attributes.
+
+ Args:
+ binding_id: Unique identifier for the binding
+ provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
+ model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
+ collection_name: Name of the vector database collection
+ collection_type: Type of collection (default: "dataset")
+ created_at: Optional datetime for creation timestamp
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a DatasetCollectionBinding instance
+ """
+ binding = Mock(spec=DatasetCollectionBinding)
+ binding.id = binding_id
+ binding.provider_name = provider_name
+ binding.model_name = model_name
+ binding.collection_name = collection_name
+ binding.type = collection_type
+ binding.created_at = created_at
+ for key, value in kwargs.items():
+ setattr(binding, key, value)
+ return binding
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Dataset for testing collection name generation.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+
+# ============================================================================
+# Tests for get_dataset_collection_binding
+# ============================================================================
+
+
+class TestDatasetCollectionBindingServiceGetBinding:
+ """
+ Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding method.
+
+ This test class covers the main collection binding retrieval/creation functionality,
+ including various provider/model combinations, collection types, and edge cases.
+
+ The get_dataset_collection_binding method:
+ 1. Queries for existing binding by provider_name, model_name, and collection_type
+ 2. Orders results by created_at (ascending) and takes the first match
+ 3. If no binding exists, creates a new one with:
+ - The provided provider_name and model_name
+ - A generated collection_name using Dataset.gen_collection_name_by_id
+ - The provided collection_type
+ 4. Adds the new binding to the database session and commits
+ 5. Returns the binding (either existing or newly created)
+
+ Test scenarios include:
+ - Retrieving existing bindings
+ - Creating new bindings when none exist
+ - Different collection types
+ - Database transaction handling
+ - Collection name generation
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides a mocked database session that can be used to verify:
+ - Query construction and execution
+ - Add operations for new bindings
+ - Commit operations for transaction completion
+
+ The mock is configured to return a query builder that supports
+ chaining operations like .where(), .order_by(), and .first().
+ """
+ with patch("services.dataset_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_get_dataset_collection_binding_existing_binding_success(self, mock_db_session):
+ """
+ Test successful retrieval of an existing collection binding.
+
+ Verifies that when a binding already exists in the database for the given
+ provider, model, and collection type, the method returns the existing binding
+ without creating a new one.
+
+ This test ensures:
+ - The query is constructed correctly with all three filters
+ - Results are ordered by created_at
+ - The first matching binding is returned
+ - No new binding is created (db.session.add is not called)
+ - No commit is performed (db.session.commit is not called)
+ """
+ # Arrange
+ provider_name = "openai"
+ model_name = "text-embedding-ada-002"
+ collection_type = "dataset"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-123",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain: query().where().order_by().first()
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == "binding-123"
+ assert result.provider_name == provider_name
+ assert result.model_name == model_name
+ assert result.type == collection_type
+
+ # Verify query was constructed correctly
+ # The query should be constructed with DatasetCollectionBinding as the model
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ # Verify the where clause was applied to filter by provider, model, and type
+ mock_query.where.assert_called_once()
+
+ # Verify the results were ordered by created_at (ascending)
+ # This ensures we get the oldest binding if multiple exist
+ mock_where.order_by.assert_called_once()
+
+ # Verify no new binding was created
+ # Since an existing binding was found, we should not create a new one
+ mock_db_session.add.assert_not_called()
+
+ # Verify no commit was performed
+ # Since no new binding was created, no database transaction is needed
+ mock_db_session.commit.assert_not_called()
+
+ def test_get_dataset_collection_binding_create_new_binding_success(self, mock_db_session):
+ """
+ Test successful creation of a new collection binding when none exists.
+
+ Verifies that when no binding exists in the database for the given
+ provider, model, and collection type, the method creates a new binding
+ with a generated collection name and commits it to the database.
+
+ This test ensures:
+ - The query returns None (no existing binding)
+ - A new DatasetCollectionBinding is created with correct attributes
+ - Dataset.gen_collection_name_by_id is called to generate collection name
+ - The new binding is added to the database session
+ - The transaction is committed
+ - The newly created binding is returned
+ """
+ # Arrange
+ provider_name = "cohere"
+ model_name = "embed-english-v3.0"
+ collection_type = "dataset"
+ generated_collection_name = "collection-generated-xyz"
+
+ # Mock the query chain to return None (no existing binding)
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = None # No existing binding
+ mock_db_session.query.return_value = mock_query
+
+ # Mock Dataset.gen_collection_name_by_id to return a generated name
+ with patch("services.dataset_service.Dataset.gen_collection_name_by_id") as mock_gen_name:
+ mock_gen_name.return_value = generated_collection_name
+
+ # Mock uuid.uuid4 for the collection name generation
+ mock_uuid = "test-uuid-123"
+ with patch("services.dataset_service.uuid.uuid4", return_value=mock_uuid):
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result is not None
+ assert result.provider_name == provider_name
+ assert result.model_name == model_name
+ assert result.type == collection_type
+ assert result.collection_name == generated_collection_name
+
+ # Verify Dataset.gen_collection_name_by_id was called with the generated UUID
+ # This method generates a unique collection name based on the UUID
+ # The UUID is converted to string before passing to the method
+ mock_gen_name.assert_called_once_with(str(mock_uuid))
+
+ # Verify new binding was added to the database session
+ # The add method should be called exactly once with the new binding instance
+ mock_db_session.add.assert_called_once()
+
+ # Extract the binding that was added to verify its properties
+ added_binding = mock_db_session.add.call_args[0][0]
+
+ # Verify the added binding is an instance of DatasetCollectionBinding
+ # This ensures we're creating the correct type of object
+ assert isinstance(added_binding, DatasetCollectionBinding)
+
+ # Verify all the binding properties are set correctly
+ # These should match the input parameters to the method
+ assert added_binding.provider_name == provider_name
+ assert added_binding.model_name == model_name
+ assert added_binding.type == collection_type
+
+ # Verify the collection name was set from the generated name
+ # This ensures the binding has a valid collection identifier
+ assert added_binding.collection_name == generated_collection_name
+
+ # Verify the transaction was committed
+ # This ensures the new binding is persisted to the database
+ mock_db_session.commit.assert_called_once()
+
+ def test_get_dataset_collection_binding_different_collection_type(self, mock_db_session):
+ """
+ Test retrieval with a different collection type (not "dataset").
+
+ Verifies that the method correctly filters by collection_type, allowing
+ different types of collections to coexist with the same provider/model
+ combination.
+
+ This test ensures:
+ - Collection type is properly used as a filter in the query
+ - Different collection types can have separate bindings
+ - The correct binding is returned based on type
+ """
+ # Arrange
+ provider_name = "openai"
+ model_name = "text-embedding-ada-002"
+ collection_type = "custom_type"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-456",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.type == collection_type
+
+ # Verify query was constructed with the correct type filter
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_default_collection_type(self, mock_db_session):
+ """
+ Test retrieval with default collection type ("dataset").
+
+ Verifies that when collection_type is not provided, it defaults to "dataset"
+ as specified in the method signature.
+
+ This test ensures:
+ - The default value "dataset" is used when type is not specified
+ - The query correctly filters by the default type
+ """
+ # Arrange
+ provider_name = "openai"
+ model_name = "text-embedding-ada-002"
+ # collection_type defaults to "dataset" in method signature
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-789",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type="dataset", # Default type
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act - call without specifying collection_type (uses default)
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.type == "dataset"
+
+ # Verify query was constructed correctly
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ def test_get_dataset_collection_binding_different_provider_model_combination(self, mock_db_session):
+ """
+ Test retrieval with different provider/model combinations.
+
+ Verifies that bindings are correctly filtered by both provider_name and
+ model_name, ensuring that different model combinations have separate bindings.
+
+ This test ensures:
+ - Provider and model are both used as filters
+ - Different combinations result in different bindings
+ - The correct binding is returned for each combination
+ """
+ # Arrange
+ provider_name = "huggingface"
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
+ collection_type = "dataset"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id="binding-hf-123",
+ provider_name=provider_name,
+ model_name=model_name,
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding(
+ provider_name=provider_name, model_name=model_name, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.provider_name == provider_name
+ assert result.model_name == model_name
+
+ # Verify query filters were applied correctly
+ # The query should filter by both provider_name and model_name
+ # This ensures different model combinations have separate bindings
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ # Verify the where clause was applied with all three filters:
+ # - provider_name filter
+ # - model_name filter
+ # - collection_type filter
+ mock_query.where.assert_called_once()
+
+
+# ============================================================================
+# Tests for get_dataset_collection_binding_by_id_and_type
+# ============================================================================
+# This section contains tests for the get_dataset_collection_binding_by_id_and_type
+# method, which retrieves a specific collection binding by its ID and type.
+#
+# Key differences from get_dataset_collection_binding:
+# 1. This method queries by ID and type, not by provider/model/type
+# 2. This method does NOT create a new binding if one doesn't exist
+# 3. This method raises ValueError if the binding is not found
+# 4. This method is typically used when you already know the binding ID
+#
+# Use cases:
+# - Retrieving a binding that was previously created
+# - Validating that a binding exists before using it
+# - Accessing binding metadata when you have the ID
+#
+# ============================================================================
+
+
+class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
+ """
+ Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type method.
+
+ This test class covers collection binding retrieval by ID and type,
+ including success scenarios and error handling for missing bindings.
+
+ The get_dataset_collection_binding_by_id_and_type method:
+ 1. Queries for a binding by collection_binding_id and collection_type
+ 2. Orders results by created_at (ascending) and takes the first match
+ 3. If no binding exists, raises ValueError("Dataset collection binding not found")
+ 4. Returns the found binding
+
+ Unlike get_dataset_collection_binding, this method does NOT create a new
+ binding if one doesn't exist - it only retrieves existing bindings.
+
+ Test scenarios include:
+ - Successful retrieval of existing bindings
+ - Error handling for missing bindings
+ - Different collection types
+ - Default collection type behavior
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides a mocked database session that can be used to verify:
+ - Query construction with ID and type filters
+ - Ordering by created_at
+ - First result retrieval
+
+ The mock is configured to return a query builder that supports
+ chaining operations like .where(), .order_by(), and .first().
+ """
+ with patch("services.dataset_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_get_dataset_collection_binding_by_id_and_type_success(self, mock_db_session):
+ """
+ Test successful retrieval of a collection binding by ID and type.
+
+ Verifies that when a binding exists in the database with the given
+ ID and collection type, the method returns the binding.
+
+ This test ensures:
+ - The query is constructed correctly with ID and type filters
+ - Results are ordered by created_at
+ - The first matching binding is returned
+ - No error is raised
+ """
+ # Arrange
+ collection_binding_id = "binding-123"
+ collection_type = "dataset"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id=collection_binding_id,
+ provider_name="openai",
+ model_name="text-embedding-ada-002",
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain: query().where().order_by().first()
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == collection_binding_id
+ assert result.type == collection_type
+
+ # Verify query was constructed correctly
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+ mock_where.order_by.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, mock_db_session):
+ """
+ Test error handling when binding is not found.
+
+ Verifies that when no binding exists in the database with the given
+ ID and collection type, the method raises a ValueError with the
+ message "Dataset collection binding not found".
+
+ This test ensures:
+ - The query returns None (no existing binding)
+ - ValueError is raised with the correct message
+ - No binding is returned
+ """
+ # Arrange
+ collection_binding_id = "non-existent-binding"
+ collection_type = "dataset"
+
+ # Mock the query chain to return None (no existing binding)
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = None # No existing binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Dataset collection binding not found"):
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Verify query was attempted
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, mock_db_session):
+ """
+ Test retrieval with a different collection type.
+
+ Verifies that the method correctly filters by collection_type, ensuring
+ that bindings with the same ID but different types are treated as
+ separate entities.
+
+ This test ensures:
+ - Collection type is properly used as a filter in the query
+ - Different collection types can have separate bindings with same ID
+ - The correct binding is returned based on type
+ """
+ # Arrange
+ collection_binding_id = "binding-456"
+ collection_type = "custom_type"
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id=collection_binding_id,
+ provider_name="cohere",
+ model_name="embed-english-v3.0",
+ collection_type=collection_type,
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == collection_binding_id
+ assert result.type == collection_type
+
+ # Verify query was constructed with the correct type filter
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, mock_db_session):
+ """
+ Test retrieval with default collection type ("dataset").
+
+ Verifies that when collection_type is not provided, it defaults to "dataset"
+ as specified in the method signature.
+
+ This test ensures:
+ - The default value "dataset" is used when type is not specified
+ - The query correctly filters by the default type
+ - The correct binding is returned
+ """
+ # Arrange
+ collection_binding_id = "binding-789"
+ # collection_type defaults to "dataset" in method signature
+
+ existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
+ binding_id=collection_binding_id,
+ provider_name="openai",
+ model_name="text-embedding-ada-002",
+ collection_type="dataset", # Default type
+ )
+
+ # Mock the query chain
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = existing_binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act - call without specifying collection_type (uses default)
+ result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id
+ )
+
+ # Assert
+ assert result == existing_binding
+ assert result.id == collection_binding_id
+ assert result.type == "dataset"
+
+ # Verify query was constructed correctly
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+ mock_query.where.assert_called_once()
+
+ def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, mock_db_session):
+ """
+ Test error handling when binding exists but with wrong collection type.
+
+ Verifies that when a binding exists with the given ID but a different
+ collection type, the method raises a ValueError because the binding
+ doesn't match both the ID and type criteria.
+
+ This test ensures:
+ - The query correctly filters by both ID and type
+ - Bindings with matching ID but different type are not returned
+ - ValueError is raised when no matching binding is found
+ """
+ # Arrange
+ collection_binding_id = "binding-123"
+ collection_type = "dataset"
+
+ # Mock the query chain to return None (binding exists but with different type)
+ mock_query = Mock()
+ mock_where = Mock()
+ mock_order_by = Mock()
+ mock_query.where.return_value = mock_where
+ mock_where.order_by.return_value = mock_order_by
+ mock_order_by.first.return_value = None # No matching binding
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Dataset collection binding not found"):
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ collection_binding_id=collection_binding_id, collection_type=collection_type
+ )
+
+ # Verify query was attempted with both ID and type filters
+ # The query should filter by both collection_binding_id and collection_type
+ # This ensures we only get bindings that match both criteria
+ mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
+
+ # Verify the where clause was applied with both filters:
+ # - collection_binding_id filter (exact match)
+ # - collection_type filter (exact match)
+ mock_query.where.assert_called_once()
+
+ # Note: The order_by and first() calls are also part of the query chain,
+ # but we don't need to verify them separately since they're part of the
+ # standard query pattern used by both methods in this service.
+
+
+# ============================================================================
+# Additional Test Scenarios and Edge Cases
+# ============================================================================
+# The following section could contain additional test scenarios if needed:
+#
+# Potential additional tests:
+# 1. Test with multiple existing bindings (verify ordering by created_at)
+# 2. Test with very long provider/model names (boundary testing)
+# 3. Test with special characters in provider/model names
+# 4. Test concurrent binding creation (thread safety)
+# 5. Test database rollback scenarios
+# 6. Test with None values for optional parameters
+# 7. Test with empty strings for required parameters
+# 8. Test collection name generation uniqueness
+# 9. Test with different UUID formats
+# 10. Test query performance with large datasets
+#
+# These scenarios are not currently implemented but could be added if needed
+# based on real-world usage patterns or discovered edge cases.
+#
+# ============================================================================
+
+
+# ============================================================================
+# Integration Notes and Best Practices
+# ============================================================================
+#
+# When using DatasetCollectionBindingService in production code, consider:
+#
+# 1. Error Handling:
+# - Always handle ValueError exceptions when calling
+# get_dataset_collection_binding_by_id_and_type
+# - Check return values from get_dataset_collection_binding to ensure
+# bindings were created successfully
+#
+# 2. Performance Considerations:
+# - The service queries the database on every call, so consider caching
+# bindings if they're accessed frequently
+# - Collection bindings are typically long-lived, so caching is safe
+#
+# 3. Transaction Management:
+# - New bindings are automatically committed to the database
+# - If you need to rollback, ensure you're within a transaction context
+#
+# 4. Collection Type Usage:
+# - Use "dataset" for standard dataset collections
+# - Use custom types only when you need to separate collections by purpose
+# - Be consistent with collection type naming across your application
+#
+# 5. Provider and Model Naming:
+# - Use consistent provider names (e.g., "openai", not "OpenAI" or "OPENAI")
+# - Use exact model names as provided by the model provider
+# - These names are case-sensitive and must match exactly
+#
+# ============================================================================
+
+
+# ============================================================================
+# Database Schema Reference
+# ============================================================================
+#
+# The DatasetCollectionBinding model has the following structure:
+#
+# - id: StringUUID (primary key, auto-generated)
+# - provider_name: String(255) (required, e.g., "openai", "cohere")
+# - model_name: String(255) (required, e.g., "text-embedding-ada-002")
+# - type: String(40) (required, default: "dataset")
+# - collection_name: String(64) (required, unique collection identifier)
+# - created_at: DateTime (auto-generated timestamp)
+#
+# Indexes:
+# - Primary key on id
+# - Composite index on (provider_name, model_name) for efficient lookups
+#
+# Relationships:
+# - One binding can be referenced by multiple datasets
+# - Datasets reference bindings via collection_binding_id
+#
+# ============================================================================
+
+
+# ============================================================================
+# Mocking Strategy Documentation
+# ============================================================================
+#
+# This test suite uses extensive mocking to isolate the unit under test.
+# Here's how the mocking strategy works:
+#
+# 1. Database Session Mocking:
+# - db.session is patched to prevent actual database access
+# - Query chains are mocked to return predictable results
+# - Add and commit operations are tracked for verification
+#
+# 2. Query Chain Mocking:
+# - query() returns a mock query object
+# - where() returns a mock where object
+# - order_by() returns a mock order_by object
+# - first() returns the final result (binding or None)
+#
+# 3. UUID Generation Mocking:
+# - uuid.uuid4() is mocked to return predictable UUIDs
+# - This ensures collection names are generated consistently in tests
+#
+# 4. Collection Name Generation Mocking:
+# - Dataset.gen_collection_name_by_id() is mocked
+# - This allows us to verify the method is called correctly
+# - We can control the generated collection name for testing
+#
+# Benefits of this approach:
+# - Tests run quickly (no database I/O)
+# - Tests are deterministic (no random UUIDs)
+# - Tests are isolated (no side effects)
+# - Tests are maintainable (clear mock setup)
+#
+# ============================================================================
From 0fdb4e7c12330216fbcbf674815c795f3a97d9e7 Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 20:57:52 -0500
Subject: [PATCH 015/431] chore: enhance the test script of conversation
service (#28739)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/test_conversation_service.py | 1412 ++++++++++++++++-
1 file changed, 1339 insertions(+), 73 deletions(-)
diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py
index 9c1c044f03..81135dbbdf 100644
--- a/api/tests/unit_tests/services/test_conversation_service.py
+++ b/api/tests/unit_tests/services/test_conversation_service.py
@@ -1,17 +1,293 @@
+"""
+Comprehensive unit tests for ConversationService.
+
+This test suite provides complete coverage of conversation management operations in Dify,
+following TDD principles with the Arrange-Act-Assert pattern.
+
+## Test Coverage
+
+### 1. Conversation Pagination (TestConversationServicePagination)
+Tests conversation listing and filtering:
+- Empty include_ids returns empty results
+- Non-empty include_ids filters conversations properly
+- Empty exclude_ids doesn't filter results
+- Non-empty exclude_ids excludes specified conversations
+- Null user handling
+- Sorting and pagination edge cases
+
+### 2. Message Creation (TestConversationServiceMessageCreation)
+Tests message operations within conversations:
+- Message pagination without first_id
+- Message pagination with first_id specified
+- Error handling for non-existent messages
+- Empty result handling for null user/conversation
+- Message ordering (ascending/descending)
+- Has_more flag calculation
+
+### 3. Conversation Summarization (TestConversationServiceSummarization)
+Tests auto-generated conversation names:
+- Successful LLM-based name generation
+- Error handling when conversation has no messages
+- Graceful handling of LLM service failures
+- Manual vs auto-generated naming
+- Name update timestamp tracking
+
+### 4. Message Annotation (TestConversationServiceMessageAnnotation)
+Tests annotation creation and management:
+- Creating annotations from existing messages
+- Creating standalone annotations
+- Updating existing annotations
+- Paginated annotation retrieval
+- Annotation search with keywords
+- Annotation export functionality
+
+### 5. Conversation Export (TestConversationServiceExport)
+Tests data retrieval for export:
+- Successful conversation retrieval
+- Error handling for non-existent conversations
+- Message retrieval
+- Annotation export
+- Batch data export operations
+
+## Testing Approach
+
+- **Mocking Strategy**: All external dependencies (database, LLM, Redis) are mocked
+ for fast, isolated unit tests
+- **Factory Pattern**: ConversationServiceTestDataFactory provides consistent test data
+- **Fixtures**: Mock objects are configured per test method
+- **Assertions**: Each test verifies return values and side effects
+ (database operations, method calls)
+
+## Key Concepts
+
+**Conversation Sources:**
+- console: Created by workspace members
+- api: Created by end users via API
+
+**Message Pagination:**
+- first_id: Paginate from a specific message forward
+- last_id: Paginate from a specific message backward
+- Supports ascending/descending order
+
+**Annotations:**
+- Can be attached to messages or standalone
+- Support full-text search
+- Indexed for semantic retrieval
+"""
+
import uuid
-from unittest.mock import MagicMock, patch
+from datetime import UTC, datetime
+from decimal import Decimal
+from unittest.mock import MagicMock, Mock, create_autospec, patch
+
+import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
+from models import Account
+from models.model import App, Conversation, EndUser, Message, MessageAnnotation
+from services.annotation_service import AppAnnotationService
from services.conversation_service import ConversationService
+from services.errors.conversation import ConversationNotExistsError
+from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError
+from services.message_service import MessageService
-class TestConversationService:
+class ConversationServiceTestDataFactory:
+ """
+ Factory for creating test data and mock objects.
+
+ Provides reusable methods to create consistent mock objects for testing
+ conversation-related operations.
+ """
+
+ @staticmethod
+ def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
+ """
+ Create a mock Account object.
+
+ Args:
+ account_id: Unique identifier for the account
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock Account object with specified attributes
+ """
+ account = create_autospec(Account, instance=True)
+ account.id = account_id
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
+ """
+ Create a mock EndUser object.
+
+ Args:
+ user_id: Unique identifier for the end user
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock EndUser object with specified attributes
+ """
+ user = create_autospec(EndUser, instance=True)
+ user.id = user_id
+ for key, value in kwargs.items():
+ setattr(user, key, value)
+ return user
+
+ @staticmethod
+ def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
+ """
+ Create a mock App object.
+
+ Args:
+ app_id: Unique identifier for the app
+ tenant_id: Tenant/workspace identifier
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock App object with specified attributes
+ """
+ app = create_autospec(App, instance=True)
+ app.id = app_id
+ app.tenant_id = tenant_id
+ app.name = kwargs.get("name", "Test App")
+ app.mode = kwargs.get("mode", "chat")
+ app.status = kwargs.get("status", "normal")
+ for key, value in kwargs.items():
+ setattr(app, key, value)
+ return app
+
+ @staticmethod
+ def create_conversation_mock(
+ conversation_id: str = "conv-123",
+ app_id: str = "app-123",
+ from_source: str = "console",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Conversation object.
+
+ Args:
+ conversation_id: Unique identifier for the conversation
+ app_id: Associated app identifier
+ from_source: Source of conversation ('console' or 'api')
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock Conversation object with specified attributes
+ """
+ conversation = create_autospec(Conversation, instance=True)
+ conversation.id = conversation_id
+ conversation.app_id = app_id
+ conversation.from_source = from_source
+ conversation.from_end_user_id = kwargs.get("from_end_user_id")
+ conversation.from_account_id = kwargs.get("from_account_id")
+ conversation.is_deleted = kwargs.get("is_deleted", False)
+ conversation.name = kwargs.get("name", "Test Conversation")
+ conversation.status = kwargs.get("status", "normal")
+ conversation.created_at = kwargs.get("created_at", datetime.now(UTC))
+ conversation.updated_at = kwargs.get("updated_at", datetime.now(UTC))
+ for key, value in kwargs.items():
+ setattr(conversation, key, value)
+ return conversation
+
+ @staticmethod
+ def create_message_mock(
+ message_id: str = "msg-123",
+ conversation_id: str = "conv-123",
+ app_id: str = "app-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Message object.
+
+ Args:
+ message_id: Unique identifier for the message
+ conversation_id: Associated conversation identifier
+ app_id: Associated app identifier
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock Message object with specified attributes including
+ query, answer, tokens, and pricing information
+ """
+ message = create_autospec(Message, instance=True)
+ message.id = message_id
+ message.conversation_id = conversation_id
+ message.app_id = app_id
+ message.query = kwargs.get("query", "Test query")
+ message.answer = kwargs.get("answer", "Test answer")
+ message.from_source = kwargs.get("from_source", "console")
+ message.from_end_user_id = kwargs.get("from_end_user_id")
+ message.from_account_id = kwargs.get("from_account_id")
+ message.created_at = kwargs.get("created_at", datetime.now(UTC))
+ message.message = kwargs.get("message", {})
+ message.message_tokens = kwargs.get("message_tokens", 0)
+ message.answer_tokens = kwargs.get("answer_tokens", 0)
+ message.message_unit_price = kwargs.get("message_unit_price", Decimal(0))
+ message.answer_unit_price = kwargs.get("answer_unit_price", Decimal(0))
+ message.message_price_unit = kwargs.get("message_price_unit", Decimal("0.001"))
+ message.answer_price_unit = kwargs.get("answer_price_unit", Decimal("0.001"))
+ message.currency = kwargs.get("currency", "USD")
+ message.status = kwargs.get("status", "normal")
+ for key, value in kwargs.items():
+ setattr(message, key, value)
+ return message
+
+ @staticmethod
+ def create_annotation_mock(
+ annotation_id: str = "anno-123",
+ app_id: str = "app-123",
+ message_id: str = "msg-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock MessageAnnotation object.
+
+ Args:
+ annotation_id: Unique identifier for the annotation
+ app_id: Associated app identifier
+ message_id: Associated message identifier (optional for standalone annotations)
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock MessageAnnotation object with specified attributes including
+ question, content, and hit tracking
+ """
+ annotation = create_autospec(MessageAnnotation, instance=True)
+ annotation.id = annotation_id
+ annotation.app_id = app_id
+ annotation.message_id = message_id
+ annotation.conversation_id = kwargs.get("conversation_id")
+ annotation.question = kwargs.get("question", "Test question")
+ annotation.content = kwargs.get("content", "Test annotation")
+ annotation.account_id = kwargs.get("account_id", "account-123")
+ annotation.hit_count = kwargs.get("hit_count", 0)
+ annotation.created_at = kwargs.get("created_at", datetime.now(UTC))
+ annotation.updated_at = kwargs.get("updated_at", datetime.now(UTC))
+ for key, value in kwargs.items():
+ setattr(annotation, key, value)
+ return annotation
+
+
+class TestConversationServicePagination:
+ """Test conversation pagination operations."""
+
def test_pagination_with_empty_include_ids(self):
- """Test that empty include_ids returns empty result"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
+ """
+ Test that empty include_ids returns empty result.
+ When include_ids is an empty list, the service should short-circuit
+ and return empty results without querying the database.
+ """
+ # Arrange - Set up test data
+ mock_session = MagicMock() # Mock database session
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Act - Call the service method with empty include_ids
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
@@ -19,25 +295,188 @@ class TestConversationService:
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
- include_ids=[], # Empty include_ids should return empty result
+ include_ids=[], # Empty list should trigger early return
exclude_ids=None,
)
+ # Assert - Verify empty result without database query
+ assert result.data == [] # No conversations returned
+ assert result.has_more is False # No more pages available
+ assert result.limit == 20 # Limit preserved in response
+
+ def test_pagination_with_non_empty_include_ids(self):
+ """
+ Test that non-empty include_ids filters properly.
+
+ When include_ids contains conversation IDs, the query should filter
+ to only return conversations matching those IDs.
+ """
+ # Arrange - Set up test data and mocks
+ mock_session = MagicMock() # Mock database session
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Create 3 mock conversations that would match the filter
+ mock_conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=str(uuid.uuid4()))
+ for _ in range(3)
+ ]
+ # Mock the database query results
+ mock_session.scalars.return_value.all.return_value = mock_conversations
+ mock_session.scalar.return_value = 0 # No additional conversations beyond current page
+
+ # Act
+ with patch("services.conversation_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.subquery.return_value = MagicMock()
+
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=mock_user,
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ include_ids=["conv1", "conv2"],
+ exclude_ids=None,
+ )
+
+ # Assert
+ assert mock_stmt.where.called
+
+ def test_pagination_with_empty_exclude_ids(self):
+ """
+ Test that empty exclude_ids doesn't filter.
+
+ When exclude_ids is an empty list, the query should not filter out
+ any conversations.
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+ mock_conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=str(uuid.uuid4()))
+ for _ in range(5)
+ ]
+ mock_session.scalars.return_value.all.return_value = mock_conversations
+ mock_session.scalar.return_value = 0
+
+ # Act
+ with patch("services.conversation_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.subquery.return_value = MagicMock()
+
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=mock_user,
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ include_ids=None,
+ exclude_ids=[],
+ )
+
+ # Assert
+ assert len(result.data) == 5
+
+ def test_pagination_with_non_empty_exclude_ids(self):
+ """
+ Test that non-empty exclude_ids filters properly.
+
+ When exclude_ids contains conversation IDs, the query should filter
+ out conversations matching those IDs.
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+ mock_conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=str(uuid.uuid4()))
+ for _ in range(3)
+ ]
+ mock_session.scalars.return_value.all.return_value = mock_conversations
+ mock_session.scalar.return_value = 0
+
+ # Act
+ with patch("services.conversation_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.subquery.return_value = MagicMock()
+
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=mock_user,
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ include_ids=None,
+ exclude_ids=["conv1", "conv2"],
+ )
+
+ # Assert
+ assert mock_stmt.where.called
+
+ def test_pagination_returns_empty_when_user_is_none(self):
+ """
+ Test that pagination returns empty result when user is None.
+
+ This ensures proper handling of unauthenticated requests.
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+
+ # Act
+ result = ConversationService.pagination_by_last_id(
+ session=mock_session,
+ app_model=mock_app_model,
+ user=None, # No user provided
+ last_id=None,
+ limit=20,
+ invoke_from=InvokeFrom.WEB_APP,
+ )
+
+ # Assert - should return empty result without querying database
assert result.data == []
assert result.has_more is False
assert result.limit == 20
- def test_pagination_with_non_empty_include_ids(self):
- """Test that non-empty include_ids filters properly"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
+ def test_pagination_with_sorting_descending(self):
+ """
+ Test pagination with descending sort order.
- # Mock the query results
- mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
- mock_session.scalars.return_value.all.return_value = mock_conversations
+ Verifies that conversations are sorted by updated_at in descending order (newest first).
+ """
+ # Arrange
+ mock_session = MagicMock()
+ mock_app_model = ConversationServiceTestDataFactory.create_app_mock()
+ mock_user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Create conversations with different timestamps
+ conversations = [
+ ConversationServiceTestDataFactory.create_conversation_mock(
+ conversation_id=f"conv-{i}", updated_at=datetime(2024, 1, i + 1, tzinfo=UTC)
+ )
+ for i in range(3)
+ ]
+ mock_session.scalars.return_value.all.return_value = conversations
mock_session.scalar.return_value = 0
+ # Act
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
@@ -53,75 +492,902 @@ class TestConversationService:
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
- include_ids=["conv1", "conv2"], # Non-empty include_ids
- exclude_ids=None,
+ sort_by="-updated_at", # Descending sort
)
- # Verify the where clause was called with id.in_
- assert mock_stmt.where.called
+ # Assert
+ assert len(result.data) == 3
+ mock_stmt.order_by.assert_called()
- def test_pagination_with_empty_exclude_ids(self):
- """Test that empty exclude_ids doesn't filter"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
- # Mock the query results
- mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)]
- mock_session.scalars.return_value.all.return_value = mock_conversations
- mock_session.scalar.return_value = 0
+class TestConversationServiceMessageCreation:
+ """
+ Test message creation and pagination.
- with patch("services.conversation_service.select") as mock_select:
- mock_stmt = MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
- mock_stmt.order_by.return_value = mock_stmt
- mock_stmt.limit.return_value = mock_stmt
- mock_stmt.subquery.return_value = MagicMock()
+ Tests MessageService operations for creating and retrieving messages
+ within conversations.
+ """
- result = ConversationService.pagination_by_last_id(
- session=mock_session,
- app_model=mock_app_model,
- user=mock_user,
- last_id=None,
- limit=20,
- invoke_from=InvokeFrom.WEB_APP,
- include_ids=None,
- exclude_ids=[], # Empty exclude_ids should not filter
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
+ """
+ Test message pagination without specifying first_id.
+
+ When first_id is None, the service should return the most recent messages
+ up to the specified limit.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create 3 test messages in the conversation
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id
+ )
+ for i in range(3)
+ ]
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.all.return_value = messages # Final .all() returns the messages
+
+ # Act - Call the pagination method without first_id
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id=None, # No starting point specified
+ limit=10,
+ )
+
+ # Assert - Verify the results
+ assert len(result.data) == 3 # All 3 messages returned
+ assert result.has_more is False # No more messages available (3 < limit of 10)
+ # Verify conversation was looked up with correct parameters
+ mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
+ """
+ Test message pagination with first_id specified.
+
+ When first_id is provided, the service should return messages starting
+ from the specified message up to the limit.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ first_message = ConversationServiceTestDataFactory.create_message_mock(
+ message_id="msg-first", conversation_id=conversation.id
+ )
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id
+ )
+ for i in range(2)
+ ]
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.first.return_value = first_message # First message returned
+ mock_query.all.return_value = messages # Remaining messages returned
+
+ # Act - Call the pagination method with first_id
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id="msg-first",
+ limit=10,
+ )
+
+ # Assert - Verify the results
+ assert len(result.data) == 2 # Only 2 messages returned after first_id
+ assert result.has_more is False # No more messages available (2 < limit of 10)
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_by_first_id_raises_error_when_first_message_not_found(
+ self, mock_get_conversation, mock_db_session
+ ):
+ """
+ Test that FirstMessageNotExistsError is raised when first_id doesn't exist.
+
+ When the specified first_id does not exist in the conversation,
+ the service should raise an error.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.first.return_value = None # No message found for first_id
+
+ # Act & Assert
+ with pytest.raises(FirstMessageNotExistsError):
+ MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id="non-existent-msg",
+ limit=10,
)
- # Result should contain the mocked conversations
- assert len(result.data) == 5
+ def test_pagination_returns_empty_when_no_user(self):
+ """
+ Test that pagination returns empty result when user is None.
- def test_pagination_with_non_empty_exclude_ids(self):
- """Test that non-empty exclude_ids filters properly"""
- mock_session = MagicMock()
- mock_app_model = MagicMock(id=str(uuid.uuid4()))
- mock_user = MagicMock(id=str(uuid.uuid4()))
+ This ensures proper handling of unauthenticated requests.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
- # Mock the query results
- mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
- mock_session.scalars.return_value.all.return_value = mock_conversations
- mock_session.scalar.return_value = 0
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=None,
+ conversation_id="conv-123",
+ first_id=None,
+ limit=10,
+ )
- with patch("services.conversation_service.select") as mock_select:
- mock_stmt = MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
- mock_stmt.order_by.return_value = mock_stmt
- mock_stmt.limit.return_value = mock_stmt
- mock_stmt.subquery.return_value = MagicMock()
+ # Assert
+ assert result.data == []
+ assert result.has_more is False
- result = ConversationService.pagination_by_last_id(
- session=mock_session,
- app_model=mock_app_model,
- user=mock_user,
- last_id=None,
- limit=20,
- invoke_from=InvokeFrom.WEB_APP,
- include_ids=None,
- exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids
+ def test_pagination_returns_empty_when_no_conversation_id(self):
+ """
+ Test that pagination returns empty result when conversation_id is None.
+
+ This ensures proper handling of invalid requests.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id="",
+ first_id=None,
+ limit=10,
+ )
+
+ # Assert
+ assert result.data == []
+ assert result.has_more is False
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
+ """
+ Test that has_more flag is correctly set when there are more messages.
+
+ The service fetches limit+1 messages to determine if more exist.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create limit+1 messages to trigger has_more
+ limit = 5
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id
)
+ for i in range(limit + 1) # One extra message
+ ]
- # Verify the where clause was called for exclusion
- assert mock_stmt.where.called
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.all.return_value = messages # Final .all() returns the messages
+
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id=None,
+ limit=limit,
+ )
+
+ # Assert
+ assert len(result.data) == limit # Extra message should be removed
+ assert result.has_more is True # Flag should be set
+
+ @patch("services.message_service.db.session")
+ @patch("services.message_service.ConversationService.get_conversation")
+ def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
+ """
+ Test message pagination with ascending order.
+
+ Messages should be returned in chronological order (oldest first).
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create messages with different timestamps
+ messages = [
+ ConversationServiceTestDataFactory.create_message_mock(
+ message_id=f"msg-{i}", conversation_id=conversation.id, created_at=datetime(2024, 1, i + 1, tzinfo=UTC)
+ )
+ for i in range(3)
+ ]
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Set up the database query mock chain
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # WHERE clause returns self for chaining
+ mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
+ mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
+ mock_query.all.return_value = messages # Final .all() returns the messages
+
+ # Act
+ result = MessageService.pagination_by_first_id(
+ app_model=app_model,
+ user=user,
+ conversation_id=conversation.id,
+ first_id=None,
+ limit=10,
+ order="asc", # Ascending order
+ )
+
+ # Assert
+ assert len(result.data) == 3
+ # Messages should be in ascending order after reversal
+
+
+class TestConversationServiceSummarization:
+ """
+ Test conversation summarization (auto-generated names).
+
+ Tests the auto_generate_name functionality that creates conversation
+ titles based on the first message.
+ """
+
+ @patch("services.conversation_service.LLMGenerator.generate_conversation_name")
+ @patch("services.conversation_service.db.session")
+ def test_auto_generate_name_success(self, mock_db_session, mock_llm_generator):
+ """
+ Test successful auto-generation of conversation name.
+
+ The service uses an LLM to generate a descriptive name based on
+ the first message in the conversation.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Create the first message that will be used to generate the name
+ first_message = ConversationServiceTestDataFactory.create_message_mock(
+ conversation_id=conversation.id, query="What is machine learning?"
+ )
+ # Expected name from LLM
+ generated_name = "Machine Learning Discussion"
+
+ # Set up database query mock to return the first message
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by app_id and conversation_id
+ mock_query.order_by.return_value = mock_query # Order by created_at ascending
+ mock_query.first.return_value = first_message # Return the first message
+
+ # Mock the LLM to return our expected name
+ mock_llm_generator.return_value = generated_name
+
+ # Act
+ result = ConversationService.auto_generate_name(app_model, conversation)
+
+ # Assert
+ assert conversation.name == generated_name # Name updated on conversation object
+ # Verify LLM was called with correct parameters
+ mock_llm_generator.assert_called_once_with(
+ app_model.tenant_id, first_message.query, conversation.id, app_model.id
+ )
+ mock_db_session.commit.assert_called_once() # Changes committed to database
+
+ @patch("services.conversation_service.db.session")
+ def test_auto_generate_name_raises_error_when_no_message(self, mock_db_session):
+ """
+ Test that MessageNotExistsError is raised when conversation has no messages.
+
+ When the conversation has no messages, the service should raise an error.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+
+ # Set up database query mock to return no messages
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by app_id and conversation_id
+ mock_query.order_by.return_value = mock_query # Order by created_at ascending
+ mock_query.first.return_value = None # No messages found
+
+ # Act & Assert
+ with pytest.raises(MessageNotExistsError):
+ ConversationService.auto_generate_name(app_model, conversation)
+
+ @patch("services.conversation_service.LLMGenerator.generate_conversation_name")
+ @patch("services.conversation_service.db.session")
+ def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_db_session, mock_llm_generator):
+ """
+ Test that LLM generation failures are suppressed and don't crash.
+
+ When the LLM fails to generate a name, the service should not crash
+ and should return the original conversation name.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ first_message = ConversationServiceTestDataFactory.create_message_mock(conversation_id=conversation.id)
+ original_name = conversation.name
+
+ # Set up database query mock to return the first message
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by app_id and conversation_id
+ mock_query.order_by.return_value = mock_query # Order by created_at ascending
+ mock_query.first.return_value = first_message # Return the first message
+
+ # Mock the LLM to raise an exception
+ mock_llm_generator.side_effect = Exception("LLM service unavailable")
+
+ # Act
+ result = ConversationService.auto_generate_name(app_model, conversation)
+
+ # Assert
+ assert conversation.name == original_name # Name remains unchanged
+ mock_db_session.commit.assert_called_once() # Changes committed to database
+
+ @patch("services.conversation_service.db.session")
+ @patch("services.conversation_service.ConversationService.get_conversation")
+ @patch("services.conversation_service.ConversationService.auto_generate_name")
+ def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
+ """
+ Test renaming conversation with auto-generation enabled.
+
+ When auto_generate is True, the service should call the auto_generate_name
+ method to generate a new name for the conversation.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ conversation.name = "Auto-generated Name"
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Mock the auto_generate_name method to return the conversation
+ mock_auto_generate.return_value = conversation
+
+ # Act
+ result = ConversationService.rename(
+ app_model=app_model,
+ conversation_id=conversation.id,
+ user=user,
+ name="",
+ auto_generate=True,
+ )
+
+ # Assert
+ mock_auto_generate.assert_called_once_with(app_model, conversation)
+ assert result == conversation
+
+ @patch("services.conversation_service.db.session")
+ @patch("services.conversation_service.ConversationService.get_conversation")
+ @patch("services.conversation_service.naive_utc_now")
+ def test_rename_with_manual_name(self, mock_naive_utc_now, mock_get_conversation, mock_db_session):
+ """
+ Test renaming conversation with manual name.
+
+ When auto_generate is False, the service should update the conversation
+ name with the provided manual name.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock()
+ new_name = "My Custom Conversation Name"
+ mock_time = datetime(2024, 1, 1, 12, 0, 0)
+
+ # Mock the conversation lookup to return our test conversation
+ mock_get_conversation.return_value = conversation
+
+ # Mock the current time to return our mock time
+ mock_naive_utc_now.return_value = mock_time
+
+ # Act
+ result = ConversationService.rename(
+ app_model=app_model,
+ conversation_id=conversation.id,
+ user=user,
+ name=new_name,
+ auto_generate=False,
+ )
+
+ # Assert
+ assert conversation.name == new_name
+ assert conversation.updated_at == mock_time
+ mock_db_session.commit.assert_called_once()
+
+
+class TestConversationServiceMessageAnnotation:
+ """
+ Test message annotation operations.
+
+ Tests AppAnnotationService operations for creating and managing
+ message annotations.
+ """
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_create_annotation_from_message(self, mock_current_account, mock_db_session):
+ """
+ Test creating annotation from existing message.
+
+ Annotations can be attached to messages to provide curated responses
+ that override the AI-generated answers.
+ """
+ # Arrange
+ app_id = "app-123"
+ message_id = "msg-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ # Create a message that doesn't have an annotation yet
+ message = ConversationServiceTestDataFactory.create_message_mock(
+ message_id=message_id, app_id=app_id, query="What is AI?"
+ )
+ message.annotation = None # No existing annotation
+
+ # Mock the authentication context to return current user and tenant
+ mock_current_account.return_value = (account, tenant_id)
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ # First call returns app, second returns message, third returns None (no annotation setting)
+ mock_query.first.side_effect = [app, message, None]
+
+ # Annotation data to create
+ args = {"message_id": message_id, "answer": "AI is artificial intelligence"}
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+
+ # Assert
+ mock_db_session.add.assert_called_once() # Annotation added to session
+ mock_db_session.commit.assert_called_once() # Changes committed
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_create_annotation_without_message(self, mock_current_account, mock_db_session):
+ """
+ Test creating standalone annotation without message.
+
+ Annotations can be created without a message reference for bulk imports
+ or manual annotation creation.
+ """
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ # Mock the authentication context to return current user and tenant
+ mock_current_account.return_value = (account, tenant_id)
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ # First call returns app, second returns None (no message)
+ mock_query.first.side_effect = [app, None]
+
+ # Annotation data to create
+ args = {
+ "question": "What is natural language processing?",
+ "answer": "NLP is a field of AI focused on language understanding",
+ }
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+
+ # Assert
+ mock_db_session.add.assert_called_once() # Annotation added to session
+ mock_db_session.commit.assert_called_once() # Changes committed
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_update_existing_annotation(self, mock_current_account, mock_db_session):
+ """
+ Test updating an existing annotation.
+
+ When a message already has an annotation, calling the service again
+ should update the existing annotation rather than creating a new one.
+ """
+ # Arrange
+ app_id = "app-123"
+ message_id = "msg-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+ message = ConversationServiceTestDataFactory.create_message_mock(message_id=message_id, app_id=app_id)
+
+ # Create an existing annotation with old content
+ existing_annotation = ConversationServiceTestDataFactory.create_annotation_mock(
+ app_id=app_id, message_id=message_id, content="Old annotation"
+ )
+ message.annotation = existing_annotation # Message already has annotation
+
+ # Mock the authentication context to return current user and tenant
+ mock_current_account.return_value = (account, tenant_id)
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ # First call returns app, second returns message, third returns None (no annotation setting)
+ mock_query.first.side_effect = [app, message, None]
+
+ # New content to update the annotation with
+ args = {"message_id": message_id, "answer": "Updated annotation content"}
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+
+ # Assert
+ assert existing_annotation.content == "Updated annotation content" # Content updated
+ mock_db_session.add.assert_called_once() # Annotation re-added to session
+ mock_db_session.commit.assert_called_once() # Changes committed
+
+ @patch("services.annotation_service.db.paginate")
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_get_annotation_list(self, mock_current_account, mock_db_session, mock_db_paginate):
+ """
+ Test retrieving paginated annotation list.
+
+ Annotations can be retrieved in a paginated list for display in the UI.
+ """
+ """Test retrieving paginated annotation list."""
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+ annotations = [
+ ConversationServiceTestDataFactory.create_annotation_mock(annotation_id=f"anno-{i}", app_id=app_id)
+ for i in range(5)
+ ]
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = app
+
+ mock_paginate = MagicMock()
+ mock_paginate.items = annotations
+ mock_paginate.total = 5
+ mock_db_paginate.return_value = mock_paginate
+
+ # Act
+ result_items, result_total = AppAnnotationService.get_annotation_list_by_app_id(
+ app_id=app_id, page=1, limit=10, keyword=""
+ )
+
+ # Assert
+ assert len(result_items) == 5
+ assert result_total == 5
+
+ @patch("services.annotation_service.db.paginate")
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_get_annotation_list_with_keyword_search(self, mock_current_account, mock_db_session, mock_db_paginate):
+ """
+ Test retrieving annotations with keyword filtering.
+
+ Annotations can be searched by question or content using case-insensitive matching.
+ """
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ # Create annotations with searchable content
+ annotations = [
+ ConversationServiceTestDataFactory.create_annotation_mock(
+ annotation_id="anno-1",
+ app_id=app_id,
+ question="What is machine learning?",
+ content="ML is a subset of AI",
+ ),
+ ConversationServiceTestDataFactory.create_annotation_mock(
+ annotation_id="anno-2",
+ app_id=app_id,
+ question="What is deep learning?",
+ content="Deep learning uses neural networks",
+ ),
+ ]
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = app
+
+ mock_paginate = MagicMock()
+ mock_paginate.items = [annotations[0]] # Only first annotation matches
+ mock_paginate.total = 1
+ mock_db_paginate.return_value = mock_paginate
+
+ # Act
+ result_items, result_total = AppAnnotationService.get_annotation_list_by_app_id(
+ app_id=app_id,
+ page=1,
+ limit=10,
+ keyword="machine", # Search keyword
+ )
+
+ # Assert
+ assert len(result_items) == 1
+ assert result_total == 1
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_insert_annotation_directly(self, mock_current_account, mock_db_session):
+ """
+ Test direct annotation insertion without message reference.
+
+ This is used for bulk imports or manual annotation creation.
+ """
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.side_effect = [app, None]
+
+ args = {
+ "question": "What is natural language processing?",
+ "answer": "NLP is a field of AI focused on language understanding",
+ }
+
+ # Act
+ with patch("services.annotation_service.add_annotation_to_index_task"):
+ result = AppAnnotationService.insert_app_annotation_directly(args, app_id)
+
+ # Assert
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+
+class TestConversationServiceExport:
+ """
+ Test conversation export/retrieval operations.
+
+ Tests retrieving conversation data for export purposes.
+ """
+
+ @patch("services.conversation_service.db.session")
+ def test_get_conversation_success(self, mock_db_session):
+ """Test successful retrieval of conversation."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock(
+ app_id=app_model.id, from_account_id=user.id, from_source="console"
+ )
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = conversation
+
+ # Act
+ result = ConversationService.get_conversation(app_model=app_model, conversation_id=conversation.id, user=user)
+
+ # Assert
+ assert result == conversation
+
+ @patch("services.conversation_service.db.session")
+ def test_get_conversation_not_found(self, mock_db_session):
+ """Test ConversationNotExistsError when conversation doesn't exist."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = None
+
+ # Act & Assert
+ with pytest.raises(ConversationNotExistsError):
+ ConversationService.get_conversation(app_model=app_model, conversation_id="non-existent", user=user)
+
+ @patch("services.annotation_service.db.session")
+ @patch("services.annotation_service.current_account_with_tenant")
+ def test_export_annotation_list(self, mock_current_account, mock_db_session):
+ """Test exporting all annotations for an app."""
+ # Arrange
+ app_id = "app-123"
+ account = ConversationServiceTestDataFactory.create_account_mock()
+ tenant_id = "tenant-123"
+ app = ConversationServiceTestDataFactory.create_app_mock(app_id=app_id, tenant_id=tenant_id)
+ annotations = [
+ ConversationServiceTestDataFactory.create_annotation_mock(annotation_id=f"anno-{i}", app_id=app_id)
+ for i in range(10)
+ ]
+
+ mock_current_account.return_value = (account, tenant_id)
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+ mock_query.first.return_value = app
+ mock_query.all.return_value = annotations
+
+ # Act
+ result = AppAnnotationService.export_annotation_list_by_app_id(app_id)
+
+ # Assert
+ assert len(result) == 10
+ assert result == annotations
+
+ @patch("services.message_service.db.session")
+ def test_get_message_success(self, mock_db_session):
+ """Test successful retrieval of a message."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ message = ConversationServiceTestDataFactory.create_message_mock(
+ app_id=app_model.id, from_account_id=user.id, from_source="console"
+ )
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = message
+
+ # Act
+ result = MessageService.get_message(app_model=app_model, user=user, message_id=message.id)
+
+ # Assert
+ assert result == message
+
+ @patch("services.message_service.db.session")
+ def test_get_message_not_found(self, mock_db_session):
+ """Test MessageNotExistsError when message doesn't exist."""
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = None
+
+ # Act & Assert
+ with pytest.raises(MessageNotExistsError):
+ MessageService.get_message(app_model=app_model, user=user, message_id="non-existent")
+
+ @patch("services.conversation_service.db.session")
+ def test_get_conversation_for_end_user(self, mock_db_session):
+ """
+ Test retrieving conversation created by end user via API.
+
+ End users (API) and accounts (console) have different access patterns.
+ """
+ # Arrange
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ end_user = ConversationServiceTestDataFactory.create_end_user_mock()
+
+ # Conversation created by end user via API
+ conversation = ConversationServiceTestDataFactory.create_conversation_mock(
+ app_id=app_model.id,
+ from_end_user_id=end_user.id,
+ from_source="api", # API source for end users
+ )
+
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query
+ mock_query.first.return_value = conversation
+
+ # Act
+ result = ConversationService.get_conversation(
+ app_model=app_model, conversation_id=conversation.id, user=end_user
+ )
+
+ # Assert
+ assert result == conversation
+ # Verify query filters for API source
+ mock_query.where.assert_called()
+
+ @patch("services.conversation_service.delete_conversation_related_data") # Mock Celery task
+ @patch("services.conversation_service.db.session") # Mock database session
+ def test_delete_conversation(self, mock_db_session, mock_delete_task):
+ """
+ Test conversation deletion with async cleanup.
+
+ Deletion is a two-step process:
+ 1. Immediately delete the conversation record from database
+ 2. Trigger async background task to clean up related data
+ (messages, annotations, vector embeddings, file uploads)
+ """
+ # Arrange - Set up test data
+ app_model = ConversationServiceTestDataFactory.create_app_mock()
+ user = ConversationServiceTestDataFactory.create_account_mock()
+ conversation_id = "conv-to-delete"
+
+ # Set up database query mock
+ mock_query = MagicMock()
+ mock_db_session.query.return_value = mock_query
+ mock_query.where.return_value = mock_query # Filter by conversation_id
+
+ # Act - Delete the conversation
+ ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user)
+
+ # Assert - Verify two-step deletion process
+ # Step 1: Immediate database deletion
+ mock_query.delete.assert_called_once() # DELETE query executed
+ mock_db_session.commit.assert_called_once() # Transaction committed
+
+ # Step 2: Async cleanup task triggered
+ # The Celery task will handle cleanup of messages, annotations, etc.
+ mock_delete_task.delay.assert_called_once_with(conversation_id)
From 766e16b26f5974d689269c14eab7dc8a0976ece8 Mon Sep 17 00:00:00 2001
From: Satoshi Dev <162055292+0xsatoshi99@users.noreply.github.com>
Date: Wed, 26 Nov 2025 18:36:37 -0800
Subject: [PATCH 016/431] add unit tests for code node (#28717)
---
.../core/workflow/nodes/code/__init__.py | 0
.../workflow/nodes/code/code_node_spec.py | 488 ++++++++++++++++++
.../core/workflow/nodes/code/entities_spec.py | 353 +++++++++++++
3 files changed, 841 insertions(+)
create mode 100644 api/tests/unit_tests/core/workflow/nodes/code/__init__.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/__init__.py b/api/tests/unit_tests/core/workflow/nodes/code/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
new file mode 100644
index 0000000000..f62c714820
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
@@ -0,0 +1,488 @@
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.variables.types import SegmentType
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.code.exc import (
+ CodeNodeError,
+ DepthLimitError,
+ OutputValidationError,
+)
+
+
+class TestCodeNodeExceptions:
+ """Test suite for code node exceptions."""
+
+ def test_code_node_error_is_value_error(self):
+ """Test CodeNodeError inherits from ValueError."""
+ error = CodeNodeError("test error")
+
+ assert isinstance(error, ValueError)
+ assert str(error) == "test error"
+
+ def test_output_validation_error_is_code_node_error(self):
+ """Test OutputValidationError inherits from CodeNodeError."""
+ error = OutputValidationError("validation failed")
+
+ assert isinstance(error, CodeNodeError)
+ assert isinstance(error, ValueError)
+ assert str(error) == "validation failed"
+
+ def test_depth_limit_error_is_code_node_error(self):
+ """Test DepthLimitError inherits from CodeNodeError."""
+ error = DepthLimitError("depth exceeded")
+
+ assert isinstance(error, CodeNodeError)
+ assert isinstance(error, ValueError)
+ assert str(error) == "depth exceeded"
+
+ def test_code_node_error_with_empty_message(self):
+ """Test CodeNodeError with empty message."""
+ error = CodeNodeError("")
+
+ assert str(error) == ""
+
+ def test_output_validation_error_with_field_info(self):
+ """Test OutputValidationError with field information."""
+ error = OutputValidationError("Output 'result' is not a valid type")
+
+ assert "result" in str(error)
+ assert "not a valid type" in str(error)
+
+ def test_depth_limit_error_with_limit_info(self):
+ """Test DepthLimitError with limit information."""
+ error = DepthLimitError("Depth limit 5 reached, object too deep")
+
+ assert "5" in str(error)
+ assert "too deep" in str(error)
+
+
+class TestCodeNodeClassMethods:
+ """Test suite for CodeNode class methods."""
+
+ def test_code_node_version(self):
+ """Test CodeNode version method."""
+ version = CodeNode.version()
+
+ assert version == "1"
+
+ def test_get_default_config_python3(self):
+ """Test get_default_config for Python3."""
+ config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.PYTHON3})
+
+ assert config is not None
+ assert isinstance(config, dict)
+
+ def test_get_default_config_javascript(self):
+ """Test get_default_config for JavaScript."""
+ config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.JAVASCRIPT})
+
+ assert config is not None
+ assert isinstance(config, dict)
+
+ def test_get_default_config_no_filters(self):
+ """Test get_default_config with no filters defaults to Python3."""
+ config = CodeNode.get_default_config()
+
+ assert config is not None
+ assert isinstance(config, dict)
+
+ def test_get_default_config_empty_filters(self):
+ """Test get_default_config with empty filters."""
+ config = CodeNode.get_default_config(filters={})
+
+ assert config is not None
+
+
+class TestCodeNodeCheckMethods:
+ """Test suite for CodeNode check methods."""
+
+ def test_check_string_none_value(self):
+ """Test _check_string with None value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string(None, "test_var")
+
+ assert result is None
+
+ def test_check_string_removes_null_bytes(self):
+ """Test _check_string removes null bytes."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("hello\x00world", "test_var")
+
+ assert result == "helloworld"
+ assert "\x00" not in result
+
+ def test_check_string_valid_string(self):
+ """Test _check_string with valid string."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("valid string", "test_var")
+
+ assert result == "valid string"
+
+ def test_check_string_empty_string(self):
+ """Test _check_string with empty string."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("", "test_var")
+
+ assert result == ""
+
+ def test_check_string_with_unicode(self):
+ """Test _check_string with unicode characters."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_string("你好世界🌍", "test_var")
+
+ assert result == "你好世界🌍"
+
+ def test_check_boolean_none_value(self):
+ """Test _check_boolean with None value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_boolean(None, "test_var")
+
+ assert result is None
+
+ def test_check_boolean_true_value(self):
+ """Test _check_boolean with True value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_boolean(True, "test_var")
+
+ assert result is True
+
+ def test_check_boolean_false_value(self):
+ """Test _check_boolean with False value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_boolean(False, "test_var")
+
+ assert result is False
+
+ def test_check_number_none_value(self):
+ """Test _check_number with None value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(None, "test_var")
+
+ assert result is None
+
+ def test_check_number_integer_value(self):
+ """Test _check_number with integer value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(42, "test_var")
+
+ assert result == 42
+
+ def test_check_number_float_value(self):
+ """Test _check_number with float value."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(3.14, "test_var")
+
+ assert result == 3.14
+
+ def test_check_number_zero(self):
+ """Test _check_number with zero."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(0, "test_var")
+
+ assert result == 0
+
+ def test_check_number_negative(self):
+ """Test _check_number with negative number."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(-100, "test_var")
+
+ assert result == -100
+
+ def test_check_number_negative_float(self):
+ """Test _check_number with negative float."""
+ node = CodeNode.__new__(CodeNode)
+ result = node._check_number(-3.14159, "test_var")
+
+ assert result == -3.14159
+
+
+class TestCodeNodeConvertBooleanToInt:
+ """Test suite for _convert_boolean_to_int static method."""
+
+ def test_convert_none_returns_none(self):
+ """Test converting None returns None."""
+ result = CodeNode._convert_boolean_to_int(None)
+
+ assert result is None
+
+ def test_convert_true_returns_one(self):
+ """Test converting True returns 1."""
+ result = CodeNode._convert_boolean_to_int(True)
+
+ assert result == 1
+ assert isinstance(result, int)
+
+ def test_convert_false_returns_zero(self):
+ """Test converting False returns 0."""
+ result = CodeNode._convert_boolean_to_int(False)
+
+ assert result == 0
+ assert isinstance(result, int)
+
+ def test_convert_integer_returns_same(self):
+ """Test converting integer returns same value."""
+ result = CodeNode._convert_boolean_to_int(42)
+
+ assert result == 42
+
+ def test_convert_float_returns_same(self):
+ """Test converting float returns same value."""
+ result = CodeNode._convert_boolean_to_int(3.14)
+
+ assert result == 3.14
+
+ def test_convert_zero_returns_zero(self):
+ """Test converting zero returns zero."""
+ result = CodeNode._convert_boolean_to_int(0)
+
+ assert result == 0
+
+ def test_convert_negative_returns_same(self):
+ """Test converting negative number returns same value."""
+ result = CodeNode._convert_boolean_to_int(-100)
+
+ assert result == -100
+
+
+class TestCodeNodeExtractVariableSelector:
+ """Test suite for _extract_variable_selector_to_variable_mapping."""
+
+ def test_extract_empty_variables(self):
+ """Test extraction with no variables."""
+ node_data = {
+ "title": "Test",
+ "variables": [],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="node_1",
+ node_data=node_data,
+ )
+
+ assert result == {}
+
+ def test_extract_single_variable(self):
+ """Test extraction with single variable."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "input_text", "value_selector": ["start", "text"]},
+ ],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="node_1",
+ node_data=node_data,
+ )
+
+ assert "node_1.input_text" in result
+ assert result["node_1.input_text"] == ["start", "text"]
+
+ def test_extract_multiple_variables(self):
+ """Test extraction with multiple variables."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "var1", "value_selector": ["node_a", "output1"]},
+ {"variable": "var2", "value_selector": ["node_b", "output2"]},
+ {"variable": "var3", "value_selector": ["node_c", "output3"]},
+ ],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="code_node",
+ node_data=node_data,
+ )
+
+ assert len(result) == 3
+ assert "code_node.var1" in result
+ assert "code_node.var2" in result
+ assert "code_node.var3" in result
+
+ def test_extract_with_nested_selector(self):
+ """Test extraction with nested value selector."""
+ node_data = {
+ "title": "Test",
+ "variables": [
+ {"variable": "deep_var", "value_selector": ["node", "obj", "nested", "value"]},
+ ],
+ "code_language": "python3",
+ "code": "def main(): return {}",
+ "outputs": {},
+ }
+
+ result = CodeNode._extract_variable_selector_to_variable_mapping(
+ graph_config={},
+ node_id="node_x",
+ node_data=node_data,
+ )
+
+ assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"]
+
+
+class TestCodeNodeDataValidation:
+ """Test suite for CodeNodeData validation scenarios."""
+
+ def test_valid_python3_code_node_data(self):
+ """Test valid Python3 CodeNodeData."""
+ data = CodeNodeData(
+ title="Python Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'result': 1}",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert data.code_language == CodeLanguage.PYTHON3
+
+ def test_valid_javascript_code_node_data(self):
+ """Test valid JavaScript CodeNodeData."""
+ data = CodeNodeData(
+ title="JS Code",
+ variables=[],
+ code_language=CodeLanguage.JAVASCRIPT,
+ code="function main() { return { result: 1 }; }",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert data.code_language == CodeLanguage.JAVASCRIPT
+
+ def test_code_node_data_with_all_output_types(self):
+ """Test CodeNodeData with all valid output types."""
+ data = CodeNodeData(
+ title="All Types",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {}",
+ outputs={
+ "str_out": CodeNodeData.Output(type=SegmentType.STRING),
+ "num_out": CodeNodeData.Output(type=SegmentType.NUMBER),
+ "bool_out": CodeNodeData.Output(type=SegmentType.BOOLEAN),
+ "obj_out": CodeNodeData.Output(type=SegmentType.OBJECT),
+ "arr_str": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
+ "arr_num": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER),
+ "arr_bool": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN),
+ "arr_obj": CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT),
+ },
+ )
+
+ assert len(data.outputs) == 8
+
+ def test_code_node_data_complex_nested_output(self):
+ """Test CodeNodeData with complex nested output structure."""
+ data = CodeNodeData(
+ title="Complex Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {}",
+ outputs={
+ "response": CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "data": CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
+ "count": CodeNodeData.Output(type=SegmentType.NUMBER),
+ },
+ ),
+ "status": CodeNodeData.Output(type=SegmentType.STRING),
+ "success": CodeNodeData.Output(type=SegmentType.BOOLEAN),
+ },
+ ),
+ },
+ )
+
+ assert data.outputs["response"].type == SegmentType.OBJECT
+ assert data.outputs["response"].children is not None
+ assert "data" in data.outputs["response"].children
+ assert data.outputs["response"].children["data"].children is not None
+
+
+class TestCodeNodeInitialization:
+ """Test suite for CodeNode initialization methods."""
+
+ def test_init_node_data_python3(self):
+ """Test init_node_data with Python3 configuration."""
+ node = CodeNode.__new__(CodeNode)
+ data = {
+ "title": "Test Node",
+ "variables": [],
+ "code_language": "python3",
+ "code": "def main(): return {'x': 1}",
+ "outputs": {"x": {"type": "number"}},
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.title == "Test Node"
+ assert node._node_data.code_language == CodeLanguage.PYTHON3
+
+ def test_init_node_data_javascript(self):
+ """Test init_node_data with JavaScript configuration."""
+ node = CodeNode.__new__(CodeNode)
+ data = {
+ "title": "JS Node",
+ "variables": [],
+ "code_language": "javascript",
+ "code": "function main() { return { x: 1 }; }",
+ "outputs": {"x": {"type": "number"}},
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.code_language == CodeLanguage.JAVASCRIPT
+
+ def test_get_title(self):
+ """Test _get_title method."""
+ node = CodeNode.__new__(CodeNode)
+ node._node_data = CodeNodeData(
+ title="My Code Node",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ assert node._get_title() == "My Code Node"
+
+ def test_get_description_none(self):
+ """Test _get_description returns None when not set."""
+ node = CodeNode.__new__(CodeNode)
+ node._node_data = CodeNodeData(
+ title="Test",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ assert node._get_description() is None
+
+ def test_get_base_node_data(self):
+ """Test get_base_node_data returns node data."""
+ node = CodeNode.__new__(CodeNode)
+ node._node_data = CodeNodeData(
+ title="Base Test",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ result = node.get_base_node_data()
+
+ assert result == node._node_data
+ assert result.title == "Base Test"
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py
new file mode 100644
index 0000000000..d14a6ea69c
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py
@@ -0,0 +1,353 @@
+import pytest
+from pydantic import ValidationError
+
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.variables.types import SegmentType
+from core.workflow.nodes.code.entities import CodeNodeData
+
+
+class TestCodeNodeDataOutput:
+ """Test suite for CodeNodeData.Output model."""
+
+ def test_output_with_string_type(self):
+ """Test Output with STRING type."""
+ output = CodeNodeData.Output(type=SegmentType.STRING)
+
+ assert output.type == SegmentType.STRING
+ assert output.children is None
+
+ def test_output_with_number_type(self):
+ """Test Output with NUMBER type."""
+ output = CodeNodeData.Output(type=SegmentType.NUMBER)
+
+ assert output.type == SegmentType.NUMBER
+ assert output.children is None
+
+ def test_output_with_boolean_type(self):
+ """Test Output with BOOLEAN type."""
+ output = CodeNodeData.Output(type=SegmentType.BOOLEAN)
+
+ assert output.type == SegmentType.BOOLEAN
+
+ def test_output_with_object_type(self):
+ """Test Output with OBJECT type."""
+ output = CodeNodeData.Output(type=SegmentType.OBJECT)
+
+ assert output.type == SegmentType.OBJECT
+
+ def test_output_with_array_string_type(self):
+ """Test Output with ARRAY_STRING type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING)
+
+ assert output.type == SegmentType.ARRAY_STRING
+
+ def test_output_with_array_number_type(self):
+ """Test Output with ARRAY_NUMBER type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)
+
+ assert output.type == SegmentType.ARRAY_NUMBER
+
+ def test_output_with_array_object_type(self):
+ """Test Output with ARRAY_OBJECT type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT)
+
+ assert output.type == SegmentType.ARRAY_OBJECT
+
+ def test_output_with_array_boolean_type(self):
+ """Test Output with ARRAY_BOOLEAN type."""
+ output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)
+
+ assert output.type == SegmentType.ARRAY_BOOLEAN
+
+ def test_output_with_nested_children(self):
+ """Test Output with nested children for OBJECT type."""
+ child_output = CodeNodeData.Output(type=SegmentType.STRING)
+ parent_output = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={"name": child_output},
+ )
+
+ assert parent_output.type == SegmentType.OBJECT
+ assert parent_output.children is not None
+ assert "name" in parent_output.children
+ assert parent_output.children["name"].type == SegmentType.STRING
+
+ def test_output_with_deeply_nested_children(self):
+ """Test Output with deeply nested children."""
+ inner_child = CodeNodeData.Output(type=SegmentType.NUMBER)
+ middle_child = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={"value": inner_child},
+ )
+ outer_output = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={"nested": middle_child},
+ )
+
+ assert outer_output.children is not None
+ assert outer_output.children["nested"].children is not None
+ assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER
+
+ def test_output_with_multiple_children(self):
+ """Test Output with multiple children."""
+ output = CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ "age": CodeNodeData.Output(type=SegmentType.NUMBER),
+ "active": CodeNodeData.Output(type=SegmentType.BOOLEAN),
+ },
+ )
+
+ assert output.children is not None
+ assert len(output.children) == 3
+ assert output.children["name"].type == SegmentType.STRING
+ assert output.children["age"].type == SegmentType.NUMBER
+ assert output.children["active"].type == SegmentType.BOOLEAN
+
+ def test_output_rejects_invalid_type(self):
+ """Test Output rejects invalid segment types."""
+ with pytest.raises(ValidationError):
+ CodeNodeData.Output(type=SegmentType.FILE)
+
+ def test_output_rejects_array_file_type(self):
+ """Test Output rejects ARRAY_FILE type."""
+ with pytest.raises(ValidationError):
+ CodeNodeData.Output(type=SegmentType.ARRAY_FILE)
+
+
+class TestCodeNodeDataDependency:
+ """Test suite for CodeNodeData.Dependency model."""
+
+ def test_dependency_basic(self):
+ """Test Dependency with name and version."""
+ dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0")
+
+ assert dependency.name == "numpy"
+ assert dependency.version == "1.24.0"
+
+ def test_dependency_with_complex_version(self):
+ """Test Dependency with complex version string."""
+ dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0")
+
+ assert dependency.name == "pandas"
+ assert dependency.version == ">=2.0.0,<3.0.0"
+
+ def test_dependency_with_empty_version(self):
+ """Test Dependency with empty version."""
+ dependency = CodeNodeData.Dependency(name="requests", version="")
+
+ assert dependency.name == "requests"
+ assert dependency.version == ""
+
+
+class TestCodeNodeData:
+ """Test suite for CodeNodeData model."""
+
+ def test_code_node_data_python3(self):
+ """Test CodeNodeData with Python3 language."""
+ data = CodeNodeData(
+ title="Test Code Node",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'result': 42}",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert data.title == "Test Code Node"
+ assert data.code_language == CodeLanguage.PYTHON3
+ assert data.code == "def main(): return {'result': 42}"
+ assert "result" in data.outputs
+ assert data.dependencies is None
+
+ def test_code_node_data_javascript(self):
+ """Test CodeNodeData with JavaScript language."""
+ data = CodeNodeData(
+ title="JS Code Node",
+ variables=[],
+ code_language=CodeLanguage.JAVASCRIPT,
+ code="function main() { return { result: 'hello' }; }",
+ outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)},
+ )
+
+ assert data.code_language == CodeLanguage.JAVASCRIPT
+ assert "result" in data.outputs
+ assert data.outputs["result"].type == SegmentType.STRING
+
+ def test_code_node_data_with_dependencies(self):
+ """Test CodeNodeData with dependencies."""
+ data = CodeNodeData(
+ title="Code with Deps",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="import numpy as np\ndef main(): return {'sum': 10}",
+ outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ dependencies=[
+ CodeNodeData.Dependency(name="numpy", version="1.24.0"),
+ CodeNodeData.Dependency(name="pandas", version="2.0.0"),
+ ],
+ )
+
+ assert data.dependencies is not None
+ assert len(data.dependencies) == 2
+ assert data.dependencies[0].name == "numpy"
+ assert data.dependencies[1].name == "pandas"
+
+ def test_code_node_data_with_multiple_outputs(self):
+ """Test CodeNodeData with multiple outputs."""
+ data = CodeNodeData(
+ title="Multi Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}",
+ outputs={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ "count": CodeNodeData.Output(type=SegmentType.NUMBER),
+ "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
+ },
+ )
+
+ assert len(data.outputs) == 3
+ assert data.outputs["name"].type == SegmentType.STRING
+ assert data.outputs["count"].type == SegmentType.NUMBER
+ assert data.outputs["items"].type == SegmentType.ARRAY_STRING
+
+ def test_code_node_data_with_object_output(self):
+ """Test CodeNodeData with nested object output."""
+ data = CodeNodeData(
+ title="Object Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'user': {'name': 'John', 'age': 30}}",
+ outputs={
+ "user": CodeNodeData.Output(
+ type=SegmentType.OBJECT,
+ children={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ "age": CodeNodeData.Output(type=SegmentType.NUMBER),
+ },
+ ),
+ },
+ )
+
+ assert data.outputs["user"].type == SegmentType.OBJECT
+ assert data.outputs["user"].children is not None
+ assert len(data.outputs["user"].children) == 2
+
+ def test_code_node_data_with_array_object_output(self):
+ """Test CodeNodeData with array of objects output."""
+ data = CodeNodeData(
+ title="Array Object Output",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}",
+ outputs={
+ "users": CodeNodeData.Output(
+ type=SegmentType.ARRAY_OBJECT,
+ children={
+ "name": CodeNodeData.Output(type=SegmentType.STRING),
+ },
+ ),
+ },
+ )
+
+ assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT
+ assert data.outputs["users"].children is not None
+
+ def test_code_node_data_empty_code(self):
+ """Test CodeNodeData with empty code."""
+ data = CodeNodeData(
+ title="Empty Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="",
+ outputs={},
+ )
+
+ assert data.code == ""
+ assert len(data.outputs) == 0
+
+ def test_code_node_data_multiline_code(self):
+ """Test CodeNodeData with multiline code."""
+ multiline_code = """
+def main():
+ result = 0
+ for i in range(10):
+ result += i
+ return {'sum': result}
+"""
+ data = CodeNodeData(
+ title="Multiline Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code=multiline_code,
+ outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
+ )
+
+ assert "for i in range(10)" in data.code
+ assert "result += i" in data.code
+
+ def test_code_node_data_with_special_characters_in_code(self):
+ """Test CodeNodeData with special characters in code."""
+ code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}"
+ data = CodeNodeData(
+ title="Special Chars",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code=code_with_special,
+ outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)},
+ )
+
+ assert "\\n" in data.code
+ assert "\\t" in data.code
+
+ def test_code_node_data_with_unicode_in_code(self):
+ """Test CodeNodeData with unicode characters in code."""
+ unicode_code = "def main(): return {'greeting': '你好世界'}"
+ data = CodeNodeData(
+ title="Unicode Code",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code=unicode_code,
+ outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)},
+ )
+
+ assert "你好世界" in data.code
+
+ def test_code_node_data_empty_dependencies_list(self):
+ """Test CodeNodeData with empty dependencies list."""
+ data = CodeNodeData(
+ title="No Deps",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {}",
+ outputs={},
+ dependencies=[],
+ )
+
+ assert data.dependencies is not None
+ assert len(data.dependencies) == 0
+
+ def test_code_node_data_with_boolean_array_output(self):
+ """Test CodeNodeData with boolean array output."""
+ data = CodeNodeData(
+ title="Boolean Array",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'flags': [True, False, True]}",
+ outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)},
+ )
+
+ assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN
+
+ def test_code_node_data_with_number_array_output(self):
+ """Test CodeNodeData with number array output."""
+ data = CodeNodeData(
+ title="Number Array",
+ variables=[],
+ code_language=CodeLanguage.PYTHON3,
+ code="def main(): return {'values': [1, 2, 3, 4, 5]}",
+ outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)},
+ )
+
+ assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER
From 5815950092b93cecc69b89f0c84f23e5a9604cc6 Mon Sep 17 00:00:00 2001
From: Satoshi Dev <162055292+0xsatoshi99@users.noreply.github.com>
Date: Wed, 26 Nov 2025 18:36:47 -0800
Subject: [PATCH 017/431] add unit tests for iteration node (#28719)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../core/workflow/nodes/iteration/__init__.py | 0
.../workflow/nodes/iteration/entities_spec.py | 339 +++++++++++++++
.../nodes/iteration/iteration_node_spec.py | 390 ++++++++++++++++++
3 files changed, 729 insertions(+)
create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
new file mode 100644
index 0000000000..d669cc7465
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
@@ -0,0 +1,339 @@
+from core.workflow.nodes.iteration.entities import (
+ ErrorHandleMode,
+ IterationNodeData,
+ IterationStartNodeData,
+ IterationState,
+)
+
+
+class TestErrorHandleMode:
+ """Test suite for ErrorHandleMode enum."""
+
+ def test_terminated_value(self):
+ """Test TERMINATED enum value."""
+ assert ErrorHandleMode.TERMINATED == "terminated"
+ assert ErrorHandleMode.TERMINATED.value == "terminated"
+
+ def test_continue_on_error_value(self):
+ """Test CONTINUE_ON_ERROR enum value."""
+ assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
+ assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error"
+
+ def test_remove_abnormal_output_value(self):
+ """Test REMOVE_ABNORMAL_OUTPUT enum value."""
+ assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output"
+ assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output"
+
+ def test_error_handle_mode_is_str_enum(self):
+ """Test ErrorHandleMode is a string enum."""
+ assert isinstance(ErrorHandleMode.TERMINATED, str)
+ assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str)
+ assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str)
+
+ def test_error_handle_mode_comparison(self):
+ """Test ErrorHandleMode can be compared with strings."""
+ assert ErrorHandleMode.TERMINATED == "terminated"
+ assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
+
+ def test_all_error_handle_modes(self):
+ """Test all ErrorHandleMode values are accessible."""
+ modes = list(ErrorHandleMode)
+
+ assert len(modes) == 3
+ assert ErrorHandleMode.TERMINATED in modes
+ assert ErrorHandleMode.CONTINUE_ON_ERROR in modes
+ assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes
+
+
+class TestIterationNodeData:
+ """Test suite for IterationNodeData model."""
+
+ def test_iteration_node_data_basic(self):
+ """Test IterationNodeData with basic configuration."""
+ data = IterationNodeData(
+ title="Test Iteration",
+ iterator_selector=["node1", "output"],
+ output_selector=["iteration", "result"],
+ )
+
+ assert data.title == "Test Iteration"
+ assert data.iterator_selector == ["node1", "output"]
+ assert data.output_selector == ["iteration", "result"]
+
+ def test_iteration_node_data_default_values(self):
+ """Test IterationNodeData default values."""
+ data = IterationNodeData(
+ title="Default Test",
+ iterator_selector=["start", "items"],
+ output_selector=["iter", "out"],
+ )
+
+ assert data.parent_loop_id is None
+ assert data.is_parallel is False
+ assert data.parallel_nums == 10
+ assert data.error_handle_mode == ErrorHandleMode.TERMINATED
+ assert data.flatten_output is True
+
+ def test_iteration_node_data_parallel_mode(self):
+ """Test IterationNodeData with parallel mode enabled."""
+ data = IterationNodeData(
+ title="Parallel Iteration",
+ iterator_selector=["node", "list"],
+ output_selector=["iter", "output"],
+ is_parallel=True,
+ parallel_nums=5,
+ )
+
+ assert data.is_parallel is True
+ assert data.parallel_nums == 5
+
+ def test_iteration_node_data_custom_parallel_nums(self):
+ """Test IterationNodeData with custom parallel numbers."""
+ data = IterationNodeData(
+ title="Custom Parallel",
+ iterator_selector=["a", "b"],
+ output_selector=["c", "d"],
+ parallel_nums=20,
+ )
+
+ assert data.parallel_nums == 20
+
+ def test_iteration_node_data_continue_on_error(self):
+ """Test IterationNodeData with continue on error mode."""
+ data = IterationNodeData(
+ title="Continue Error",
+ iterator_selector=["x", "y"],
+ output_selector=["z", "w"],
+ error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
+ )
+
+ assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
+
+ def test_iteration_node_data_remove_abnormal_output(self):
+ """Test IterationNodeData with remove abnormal output mode."""
+ data = IterationNodeData(
+ title="Remove Abnormal",
+ iterator_selector=["input", "array"],
+ output_selector=["output", "result"],
+ error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
+ )
+
+ assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
+
+ def test_iteration_node_data_flatten_output_disabled(self):
+ """Test IterationNodeData with flatten output disabled."""
+ data = IterationNodeData(
+ title="No Flatten",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ flatten_output=False,
+ )
+
+ assert data.flatten_output is False
+
+ def test_iteration_node_data_with_parent_loop_id(self):
+ """Test IterationNodeData with parent loop ID."""
+ data = IterationNodeData(
+ title="Nested Loop",
+ iterator_selector=["parent", "items"],
+ output_selector=["child", "output"],
+ parent_loop_id="parent_loop_123",
+ )
+
+ assert data.parent_loop_id == "parent_loop_123"
+
+ def test_iteration_node_data_complex_selectors(self):
+ """Test IterationNodeData with complex selectors."""
+ data = IterationNodeData(
+ title="Complex Selectors",
+ iterator_selector=["node1", "output", "data", "items"],
+ output_selector=["iteration", "result", "value"],
+ )
+
+ assert len(data.iterator_selector) == 4
+ assert len(data.output_selector) == 3
+
+ def test_iteration_node_data_all_options(self):
+ """Test IterationNodeData with all options configured."""
+ data = IterationNodeData(
+ title="Full Config",
+ iterator_selector=["start", "list"],
+ output_selector=["end", "result"],
+ parent_loop_id="outer_loop",
+ is_parallel=True,
+ parallel_nums=15,
+ error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
+ flatten_output=False,
+ )
+
+ assert data.title == "Full Config"
+ assert data.parent_loop_id == "outer_loop"
+ assert data.is_parallel is True
+ assert data.parallel_nums == 15
+ assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
+ assert data.flatten_output is False
+
+
+class TestIterationStartNodeData:
+ """Test suite for IterationStartNodeData model."""
+
+ def test_iteration_start_node_data_basic(self):
+ """Test IterationStartNodeData basic creation."""
+ data = IterationStartNodeData(title="Iteration Start")
+
+ assert data.title == "Iteration Start"
+
+ def test_iteration_start_node_data_with_description(self):
+ """Test IterationStartNodeData with description."""
+ data = IterationStartNodeData(
+ title="Start Node",
+ desc="This is the start of iteration",
+ )
+
+ assert data.title == "Start Node"
+ assert data.desc == "This is the start of iteration"
+
+
+class TestIterationState:
+ """Test suite for IterationState model."""
+
+ def test_iteration_state_default_values(self):
+ """Test IterationState default values."""
+ state = IterationState()
+
+ assert state.outputs == []
+ assert state.current_output is None
+
+ def test_iteration_state_with_outputs(self):
+ """Test IterationState with outputs."""
+ state = IterationState(outputs=["result1", "result2", "result3"])
+
+ assert len(state.outputs) == 3
+ assert state.outputs[0] == "result1"
+ assert state.outputs[2] == "result3"
+
+ def test_iteration_state_with_current_output(self):
+ """Test IterationState with current output."""
+ state = IterationState(current_output="current_value")
+
+ assert state.current_output == "current_value"
+
+ def test_iteration_state_get_last_output_with_outputs(self):
+ """Test get_last_output with outputs present."""
+ state = IterationState(outputs=["first", "second", "last"])
+
+ result = state.get_last_output()
+
+ assert result == "last"
+
+ def test_iteration_state_get_last_output_empty(self):
+ """Test get_last_output with empty outputs."""
+ state = IterationState(outputs=[])
+
+ result = state.get_last_output()
+
+ assert result is None
+
+ def test_iteration_state_get_last_output_single(self):
+ """Test get_last_output with single output."""
+ state = IterationState(outputs=["only_one"])
+
+ result = state.get_last_output()
+
+ assert result == "only_one"
+
+ def test_iteration_state_get_current_output(self):
+ """Test get_current_output method."""
+ state = IterationState(current_output={"key": "value"})
+
+ result = state.get_current_output()
+
+ assert result == {"key": "value"}
+
+ def test_iteration_state_get_current_output_none(self):
+ """Test get_current_output when None."""
+ state = IterationState()
+
+ result = state.get_current_output()
+
+ assert result is None
+
+ def test_iteration_state_with_complex_outputs(self):
+ """Test IterationState with complex output types."""
+ state = IterationState(
+ outputs=[
+ {"id": 1, "name": "first"},
+ {"id": 2, "name": "second"},
+ [1, 2, 3],
+ "string_output",
+ ]
+ )
+
+ assert len(state.outputs) == 4
+ assert state.outputs[0] == {"id": 1, "name": "first"}
+ assert state.outputs[2] == [1, 2, 3]
+
+ def test_iteration_state_with_none_outputs(self):
+ """Test IterationState with None values in outputs."""
+ state = IterationState(outputs=["value1", None, "value3"])
+
+ assert len(state.outputs) == 3
+ assert state.outputs[1] is None
+
+ def test_iteration_state_get_last_output_with_none(self):
+ """Test get_last_output when last output is None."""
+ state = IterationState(outputs=["first", None])
+
+ result = state.get_last_output()
+
+ assert result is None
+
+ def test_iteration_state_metadata_class(self):
+ """Test IterationState.MetaData class."""
+ metadata = IterationState.MetaData(iterator_length=10)
+
+ assert metadata.iterator_length == 10
+
+ def test_iteration_state_metadata_different_lengths(self):
+ """Test IterationState.MetaData with different lengths."""
+ metadata1 = IterationState.MetaData(iterator_length=0)
+ metadata2 = IterationState.MetaData(iterator_length=100)
+ metadata3 = IterationState.MetaData(iterator_length=1000000)
+
+ assert metadata1.iterator_length == 0
+ assert metadata2.iterator_length == 100
+ assert metadata3.iterator_length == 1000000
+
+ def test_iteration_state_outputs_modification(self):
+ """Test modifying IterationState outputs."""
+ state = IterationState(outputs=[])
+
+ state.outputs.append("new_output")
+ state.outputs.append("another_output")
+
+ assert len(state.outputs) == 2
+ assert state.get_last_output() == "another_output"
+
+ def test_iteration_state_current_output_update(self):
+ """Test updating current_output."""
+ state = IterationState()
+
+ state.current_output = "first_value"
+ assert state.get_current_output() == "first_value"
+
+ state.current_output = "updated_value"
+ assert state.get_current_output() == "updated_value"
+
+ def test_iteration_state_with_numeric_outputs(self):
+ """Test IterationState with numeric outputs."""
+ state = IterationState(outputs=[1, 2, 3, 4, 5])
+
+ assert state.get_last_output() == 5
+ assert len(state.outputs) == 5
+
+ def test_iteration_state_with_boolean_outputs(self):
+ """Test IterationState with boolean outputs."""
+ state = IterationState(outputs=[True, False, True])
+
+ assert state.get_last_output() is True
+ assert state.outputs[1] is False
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
new file mode 100644
index 0000000000..51af4367f7
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
@@ -0,0 +1,390 @@
+from core.workflow.enums import NodeType
+from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
+from core.workflow.nodes.iteration.exc import (
+ InvalidIteratorValueError,
+ IterationGraphNotFoundError,
+ IterationIndexNotFoundError,
+ IterationNodeError,
+ IteratorVariableNotFoundError,
+ StartNodeIdNotFoundError,
+)
+from core.workflow.nodes.iteration.iteration_node import IterationNode
+
+
+class TestIterationNodeExceptions:
+ """Test suite for iteration node exceptions."""
+
+ def test_iteration_node_error_is_value_error(self):
+ """Test IterationNodeError inherits from ValueError."""
+ error = IterationNodeError("test error")
+
+ assert isinstance(error, ValueError)
+ assert str(error) == "test error"
+
+ def test_iterator_variable_not_found_error(self):
+ """Test IteratorVariableNotFoundError."""
+ error = IteratorVariableNotFoundError("Iterator variable not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert isinstance(error, ValueError)
+ assert "Iterator variable not found" in str(error)
+
+ def test_invalid_iterator_value_error(self):
+ """Test InvalidIteratorValueError."""
+ error = InvalidIteratorValueError("Invalid iterator value")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Invalid iterator value" in str(error)
+
+ def test_start_node_id_not_found_error(self):
+ """Test StartNodeIdNotFoundError."""
+ error = StartNodeIdNotFoundError("Start node ID not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Start node ID not found" in str(error)
+
+ def test_iteration_graph_not_found_error(self):
+ """Test IterationGraphNotFoundError."""
+ error = IterationGraphNotFoundError("Iteration graph not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Iteration graph not found" in str(error)
+
+ def test_iteration_index_not_found_error(self):
+ """Test IterationIndexNotFoundError."""
+ error = IterationIndexNotFoundError("Iteration index not found")
+
+ assert isinstance(error, IterationNodeError)
+ assert "Iteration index not found" in str(error)
+
+ def test_exception_with_empty_message(self):
+ """Test exception with empty message."""
+ error = IterationNodeError("")
+
+ assert str(error) == ""
+
+ def test_exception_with_detailed_message(self):
+ """Test exception with detailed message."""
+ error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'")
+
+ assert "items" in str(error)
+ assert "start_node" in str(error)
+
+ def test_all_exceptions_inherit_from_base(self):
+ """Test all exceptions inherit from IterationNodeError."""
+ exceptions = [
+ IteratorVariableNotFoundError("test"),
+ InvalidIteratorValueError("test"),
+ StartNodeIdNotFoundError("test"),
+ IterationGraphNotFoundError("test"),
+ IterationIndexNotFoundError("test"),
+ ]
+
+ for exc in exceptions:
+ assert isinstance(exc, IterationNodeError)
+ assert isinstance(exc, ValueError)
+
+
+class TestIterationNodeClassAttributes:
+ """Test suite for IterationNode class attributes."""
+
+ def test_node_type(self):
+ """Test IterationNode node_type attribute."""
+ assert IterationNode.node_type == NodeType.ITERATION
+
+ def test_version(self):
+ """Test IterationNode version method."""
+ version = IterationNode.version()
+
+ assert version == "1"
+
+
+class TestIterationNodeDefaultConfig:
+ """Test suite for IterationNode get_default_config."""
+
+ def test_get_default_config_returns_dict(self):
+ """Test get_default_config returns a dictionary."""
+ config = IterationNode.get_default_config()
+
+ assert isinstance(config, dict)
+
+ def test_get_default_config_type(self):
+ """Test get_default_config includes type."""
+ config = IterationNode.get_default_config()
+
+ assert config.get("type") == "iteration"
+
+ def test_get_default_config_has_config_section(self):
+ """Test get_default_config has config section."""
+ config = IterationNode.get_default_config()
+
+ assert "config" in config
+ assert isinstance(config["config"], dict)
+
+ def test_get_default_config_is_parallel_default(self):
+ """Test get_default_config is_parallel default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["is_parallel"] is False
+
+ def test_get_default_config_parallel_nums_default(self):
+ """Test get_default_config parallel_nums default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["parallel_nums"] == 10
+
+ def test_get_default_config_error_handle_mode_default(self):
+ """Test get_default_config error_handle_mode default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED
+
+ def test_get_default_config_flatten_output_default(self):
+ """Test get_default_config flatten_output default value."""
+ config = IterationNode.get_default_config()
+
+ assert config["config"]["flatten_output"] is True
+
+ def test_get_default_config_with_none_filters(self):
+ """Test get_default_config with None filters."""
+ config = IterationNode.get_default_config(filters=None)
+
+ assert config is not None
+ assert "type" in config
+
+ def test_get_default_config_with_empty_filters(self):
+ """Test get_default_config with empty filters."""
+ config = IterationNode.get_default_config(filters={})
+
+ assert config is not None
+
+
+class TestIterationNodeInitialization:
+ """Test suite for IterationNode initialization."""
+
+ def test_init_node_data_basic(self):
+ """Test init_node_data with basic configuration."""
+ node = IterationNode.__new__(IterationNode)
+ data = {
+ "title": "Test Iteration",
+ "iterator_selector": ["start", "items"],
+ "output_selector": ["iteration", "result"],
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.title == "Test Iteration"
+ assert node._node_data.iterator_selector == ["start", "items"]
+
+ def test_init_node_data_with_parallel(self):
+ """Test init_node_data with parallel configuration."""
+ node = IterationNode.__new__(IterationNode)
+ data = {
+ "title": "Parallel Iteration",
+ "iterator_selector": ["node", "list"],
+ "output_selector": ["out", "result"],
+ "is_parallel": True,
+ "parallel_nums": 5,
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.is_parallel is True
+ assert node._node_data.parallel_nums == 5
+
+ def test_init_node_data_with_error_handle_mode(self):
+ """Test init_node_data with error handle mode."""
+ node = IterationNode.__new__(IterationNode)
+ data = {
+ "title": "Error Handle Test",
+ "iterator_selector": ["a", "b"],
+ "output_selector": ["c", "d"],
+ "error_handle_mode": "continue-on-error",
+ }
+
+ node.init_node_data(data)
+
+ assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
+
+ def test_get_title(self):
+ """Test _get_title method."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="My Iteration",
+ iterator_selector=["x"],
+ output_selector=["y"],
+ )
+
+ assert node._get_title() == "My Iteration"
+
+ def test_get_description_none(self):
+ """Test _get_description returns None when not set."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ assert node._get_description() is None
+
+ def test_get_description_with_value(self):
+ """Test _get_description with value."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ desc="This is a description",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ assert node._get_description() == "This is a description"
+
+ def test_get_base_node_data(self):
+ """Test get_base_node_data returns node data."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Base Test",
+ iterator_selector=["x"],
+ output_selector=["y"],
+ )
+
+ result = node.get_base_node_data()
+
+ assert result == node._node_data
+
+
+class TestIterationNodeDataValidation:
+ """Test suite for IterationNodeData validation scenarios."""
+
+ def test_valid_iteration_node_data(self):
+ """Test valid IterationNodeData creation."""
+ data = IterationNodeData(
+ title="Valid Iteration",
+ iterator_selector=["start", "items"],
+ output_selector=["end", "result"],
+ )
+
+ assert data.title == "Valid Iteration"
+
+ def test_iteration_node_data_with_all_error_modes(self):
+ """Test IterationNodeData with all error handle modes."""
+ modes = [
+ ErrorHandleMode.TERMINATED,
+ ErrorHandleMode.CONTINUE_ON_ERROR,
+ ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
+ ]
+
+ for mode in modes:
+ data = IterationNodeData(
+ title=f"Test {mode}",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ error_handle_mode=mode,
+ )
+ assert data.error_handle_mode == mode
+
+ def test_iteration_node_data_parallel_configuration(self):
+ """Test IterationNodeData parallel configuration combinations."""
+ configs = [
+ (False, 10),
+ (True, 1),
+ (True, 5),
+ (True, 20),
+ (True, 100),
+ ]
+
+ for is_parallel, parallel_nums in configs:
+ data = IterationNodeData(
+ title="Parallel Test",
+ iterator_selector=["x"],
+ output_selector=["y"],
+ is_parallel=is_parallel,
+ parallel_nums=parallel_nums,
+ )
+ assert data.is_parallel == is_parallel
+ assert data.parallel_nums == parallel_nums
+
+ def test_iteration_node_data_flatten_output_options(self):
+ """Test IterationNodeData flatten_output options."""
+ data_flatten = IterationNodeData(
+ title="Flatten True",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ flatten_output=True,
+ )
+
+ data_no_flatten = IterationNodeData(
+ title="Flatten False",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ flatten_output=False,
+ )
+
+ assert data_flatten.flatten_output is True
+ assert data_no_flatten.flatten_output is False
+
+ def test_iteration_node_data_complex_selectors(self):
+ """Test IterationNodeData with complex selectors."""
+ data = IterationNodeData(
+ title="Complex",
+ iterator_selector=["node1", "output", "data", "items", "list"],
+ output_selector=["iteration", "result", "value", "final"],
+ )
+
+ assert len(data.iterator_selector) == 5
+ assert len(data.output_selector) == 4
+
+ def test_iteration_node_data_single_element_selectors(self):
+ """Test IterationNodeData with single element selectors."""
+ data = IterationNodeData(
+ title="Single",
+ iterator_selector=["items"],
+ output_selector=["result"],
+ )
+
+ assert len(data.iterator_selector) == 1
+ assert len(data.output_selector) == 1
+
+
+class TestIterationNodeErrorStrategies:
+ """Test suite for IterationNode error strategies."""
+
+ def test_get_error_strategy_default(self):
+ """Test _get_error_strategy with default value."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ result = node._get_error_strategy()
+
+ assert result is None or result == node._node_data.error_strategy
+
+ def test_get_retry_config(self):
+ """Test _get_retry_config method."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ result = node._get_retry_config()
+
+ assert result is not None
+
+ def test_get_default_value_dict(self):
+ """Test _get_default_value_dict method."""
+ node = IterationNode.__new__(IterationNode)
+ node._node_data = IterationNodeData(
+ title="Test",
+ iterator_selector=["a"],
+ output_selector=["b"],
+ )
+
+ result = node._get_default_value_dict()
+
+ assert isinstance(result, dict)
From 01afa5616652e3cdf41029b6a4e95f0742c504d1 Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 21:37:24 -0500
Subject: [PATCH 018/431] chore: enhance the test script of current billing
service (#28747)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../services/test_billing_service.py | 1065 ++++++++++++++++-
1 file changed, 1064 insertions(+), 1 deletion(-)
diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py
index dc13143417..915aee3fa7 100644
--- a/api/tests/unit_tests/services/test_billing_service.py
+++ b/api/tests/unit_tests/services/test_billing_service.py
@@ -1,3 +1,18 @@
+"""Comprehensive unit tests for BillingService.
+
+This test module covers all aspects of the billing service including:
+- HTTP request handling with retry logic
+- Subscription tier management and billing information retrieval
+- Usage calculation and credit management (positive/negative deltas)
+- Rate limit enforcement for compliance downloads and education features
+- Account management and permission checks
+- Cache management for billing data
+- Partner integration features
+
+All tests use mocking to avoid external dependencies and ensure fast, reliable execution.
+Tests follow the Arrange-Act-Assert pattern for clarity.
+"""
+
import json
from unittest.mock import MagicMock, patch
@@ -5,11 +20,20 @@ import httpx
import pytest
from werkzeug.exceptions import InternalServerError
+from enums.cloud_plan import CloudPlan
+from models import Account, TenantAccountJoin, TenantAccountRole
from services.billing_service import BillingService
class TestBillingServiceSendRequest:
- """Unit tests for BillingService._send_request method."""
+ """Unit tests for BillingService._send_request method.
+
+ Tests cover:
+ - Successful GET/PUT/POST/DELETE requests
+ - Error handling for various HTTP status codes
+ - Retry logic on network failures
+ - Request header and parameter validation
+ """
@pytest.fixture
def mock_httpx_request(self):
@@ -234,3 +258,1042 @@ class TestBillingServiceSendRequest:
# Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts)
assert mock_httpx_request.call_count > 1
+
+
+class TestBillingServiceSubscriptionInfo:
+ """Unit tests for subscription tier and billing info retrieval.
+
+ Tests cover:
+ - Billing information retrieval
+ - Knowledge base rate limits with default and custom values
+ - Payment link generation for subscriptions and model providers
+ - Invoice retrieval
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_get_info_success(self, mock_send_request):
+ """Test successful retrieval of billing information."""
+ # Arrange
+ tenant_id = "tenant-123"
+ expected_response = {
+ "subscription_plan": "professional",
+ "billing_cycle": "monthly",
+ "status": "active",
+ }
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_info(tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": tenant_id})
+
+ def test_get_knowledge_rate_limit_with_defaults(self, mock_send_request):
+ """Test knowledge rate limit retrieval with default values."""
+ # Arrange
+ tenant_id = "tenant-456"
+ mock_send_request.return_value = {}
+
+ # Act
+ result = BillingService.get_knowledge_rate_limit(tenant_id)
+
+ # Assert
+ assert result["limit"] == 10 # Default limit
+ assert result["subscription_plan"] == CloudPlan.SANDBOX # Default plan
+ mock_send_request.assert_called_once_with(
+ "GET", "/subscription/knowledge-rate-limit", params={"tenant_id": tenant_id}
+ )
+
+ def test_get_knowledge_rate_limit_with_custom_values(self, mock_send_request):
+ """Test knowledge rate limit retrieval with custom values."""
+ # Arrange
+ tenant_id = "tenant-789"
+ mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL}
+
+ # Act
+ result = BillingService.get_knowledge_rate_limit(tenant_id)
+
+ # Assert
+ assert result["limit"] == 100
+ assert result["subscription_plan"] == CloudPlan.PROFESSIONAL
+
+ def test_get_subscription_payment_link(self, mock_send_request):
+ """Test subscription payment link generation."""
+ # Arrange
+ plan = "professional"
+ interval = "monthly"
+ email = "user@example.com"
+ tenant_id = "tenant-123"
+ expected_response = {"payment_link": "https://payment.example.com/checkout"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_subscription(plan, interval, email, tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET",
+ "/subscription/payment-link",
+ params={"plan": plan, "interval": interval, "prefilled_email": email, "tenant_id": tenant_id},
+ )
+
+ def test_get_model_provider_payment_link(self, mock_send_request):
+ """Test model provider payment link generation."""
+ # Arrange
+ provider_name = "openai"
+ tenant_id = "tenant-123"
+ account_id = "account-456"
+ email = "user@example.com"
+ expected_response = {"payment_link": "https://payment.example.com/provider"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_model_provider_payment_link(provider_name, tenant_id, account_id, email)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET",
+ "/model-provider/payment-link",
+ params={
+ "provider_name": provider_name,
+ "tenant_id": tenant_id,
+ "account_id": account_id,
+ "prefilled_email": email,
+ },
+ )
+
+ def test_get_invoices(self, mock_send_request):
+ """Test invoice retrieval."""
+ # Arrange
+ email = "user@example.com"
+ tenant_id = "tenant-123"
+ expected_response = {"invoices": [{"id": "inv-1", "amount": 100}]}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_invoices(email, tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/invoices", params={"prefilled_email": email, "tenant_id": tenant_id}
+ )
+
+
+class TestBillingServiceUsageCalculation:
+ """Unit tests for usage calculation and credit management.
+
+ Tests cover:
+ - Feature plan usage information retrieval
+ - Credit addition (positive delta)
+ - Credit consumption (negative delta)
+ - Usage refunds
+ - Specific feature usage queries
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
+ """Test retrieval of tenant feature plan usage information."""
+ # Arrange
+ tenant_id = "tenant-123"
+ expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
+
+ def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
+ """Test updating tenant feature usage with positive delta (adding credits)."""
+ # Arrange
+ tenant_id = "tenant-123"
+ feature_key = "trigger"
+ delta = 10
+ expected_response = {"result": "success", "history_id": "hist-uuid-123"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ assert result["result"] == "success"
+ assert "history_id" in result
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/tenant-feature-usage/usage",
+ params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
+ )
+
+ def test_update_tenant_feature_plan_usage_negative_delta(self, mock_send_request):
+ """Test updating tenant feature usage with negative delta (consuming credits)."""
+ # Arrange
+ tenant_id = "tenant-456"
+ feature_key = "workflow"
+ delta = -5
+ expected_response = {"result": "success", "history_id": "hist-uuid-456"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/tenant-feature-usage/usage",
+ params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
+ )
+
+ def test_refund_tenant_feature_plan_usage(self, mock_send_request):
+ """Test refunding a previous usage charge."""
+ # Arrange
+ history_id = "hist-uuid-789"
+ expected_response = {"result": "success", "history_id": history_id}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.refund_tenant_feature_plan_usage(history_id)
+
+ # Assert
+ assert result == expected_response
+ assert result["result"] == "success"
+ mock_send_request.assert_called_once_with(
+ "POST", "/tenant-feature-usage/refund", params={"quota_usage_history_id": history_id}
+ )
+
+ def test_get_tenant_feature_plan_usage(self, mock_send_request):
+ """Test getting specific feature usage for a tenant."""
+ # Arrange
+ tenant_id = "tenant-123"
+ feature_key = "trigger"
+ expected_response = {"used": 75, "limit": 100, "remaining": 25}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/billing/tenant_feature_plan/usage", params={"tenant_id": tenant_id, "feature_key": feature_key}
+ )
+
+
+class TestBillingServiceRateLimitEnforcement:
+ """Unit tests for rate limit enforcement mechanisms.
+
+ Tests cover:
+ - Compliance download rate limiting (4 requests per 60 seconds)
+ - Education verification rate limiting (10 requests per 60 seconds)
+ - Education activation rate limiting (10 requests per 60 seconds)
+ - Rate limit increment after successful operations
+ - Proper exception raising when limits are exceeded
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_compliance_download_rate_limiter_not_limited(self, mock_send_request):
+ """Test compliance download when rate limit is not exceeded."""
+ # Arrange
+ doc_name = "compliance_report.pdf"
+ account_id = "account-123"
+ tenant_id = "tenant-456"
+ ip = "192.168.1.1"
+ device_info = "Mozilla/5.0"
+ expected_response = {"download_link": "https://example.com/download"}
+
+ # Mock the rate limiter to return False (not limited)
+ with (
+ patch.object(
+ BillingService.compliance_download_rate_limiter, "is_rate_limited", return_value=False
+ ) as mock_is_limited,
+ patch.object(BillingService.compliance_download_rate_limiter, "increment_rate_limit") as mock_increment,
+ ):
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info)
+
+ # Assert
+ assert result == expected_response
+ mock_is_limited.assert_called_once_with(f"{account_id}:{tenant_id}")
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/compliance/download",
+ json={
+ "doc_name": doc_name,
+ "account_id": account_id,
+ "tenant_id": tenant_id,
+ "ip_address": ip,
+ "device_info": device_info,
+ },
+ )
+ # Verify rate limit was incremented after successful download
+ mock_increment.assert_called_once_with(f"{account_id}:{tenant_id}")
+
+ def test_compliance_download_rate_limiter_exceeded(self, mock_send_request):
+ """Test compliance download when rate limit is exceeded."""
+ # Arrange
+ doc_name = "compliance_report.pdf"
+ account_id = "account-123"
+ tenant_id = "tenant-456"
+ ip = "192.168.1.1"
+ device_info = "Mozilla/5.0"
+
+ # Import the error class to properly catch it
+ from controllers.console.error import ComplianceRateLimitError
+
+ # Mock the rate limiter to return True (rate limited)
+ with patch.object(
+ BillingService.compliance_download_rate_limiter, "is_rate_limited", return_value=True
+ ) as mock_is_limited:
+ # Act & Assert
+ with pytest.raises(ComplianceRateLimitError):
+ BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info)
+
+ mock_is_limited.assert_called_once_with(f"{account_id}:{tenant_id}")
+ mock_send_request.assert_not_called()
+
+ def test_education_verify_rate_limit_not_exceeded(self, mock_send_request):
+ """Test education verification when rate limit is not exceeded."""
+ # Arrange
+ account_id = "account-123"
+ account_email = "student@university.edu"
+ expected_response = {"verified": True, "institution": "University"}
+
+ # Mock the rate limiter to return False (not limited)
+ with (
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False
+ ) as mock_is_limited,
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"
+ ) as mock_increment,
+ ):
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.verify(account_id, account_email)
+
+ # Assert
+ assert result == expected_response
+ mock_is_limited.assert_called_once_with(account_email)
+ mock_send_request.assert_called_once_with("GET", "/education/verify", params={"account_id": account_id})
+ mock_increment.assert_called_once_with(account_email)
+
+ def test_education_verify_rate_limit_exceeded(self, mock_send_request):
+ """Test education verification when rate limit is exceeded."""
+ # Arrange
+ account_id = "account-123"
+ account_email = "student@university.edu"
+
+ # Import the error class to properly catch it
+ from controllers.console.error import EducationVerifyLimitError
+
+ # Mock the rate limiter to return True (rate limited)
+ with patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=True
+ ) as mock_is_limited:
+ # Act & Assert
+ with pytest.raises(EducationVerifyLimitError):
+ BillingService.EducationIdentity.verify(account_id, account_email)
+
+ mock_is_limited.assert_called_once_with(account_email)
+ mock_send_request.assert_not_called()
+
+ def test_education_activate_rate_limit_not_exceeded(self, mock_send_request):
+ """Test education activation when rate limit is not exceeded."""
+ # Arrange
+ account = MagicMock(spec=Account)
+ account.id = "account-123"
+ account.email = "student@university.edu"
+ account.current_tenant_id = "tenant-456"
+ token = "verification-token"
+ institution = "MIT"
+ role = "student"
+ expected_response = {"result": "success", "activated": True}
+
+ # Mock the rate limiter to return False (not limited)
+ with (
+ patch.object(
+ BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=False
+ ) as mock_is_limited,
+ patch.object(
+ BillingService.EducationIdentity.activation_rate_limit, "increment_rate_limit"
+ ) as mock_increment,
+ ):
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.activate(account, token, institution, role)
+
+ # Assert
+ assert result == expected_response
+ mock_is_limited.assert_called_once_with(account.email)
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/education/",
+ json={"institution": institution, "token": token, "role": role},
+ params={"account_id": account.id, "curr_tenant_id": account.current_tenant_id},
+ )
+ mock_increment.assert_called_once_with(account.email)
+
+ def test_education_activate_rate_limit_exceeded(self, mock_send_request):
+ """Test education activation when rate limit is exceeded."""
+ # Arrange
+ account = MagicMock(spec=Account)
+ account.id = "account-123"
+ account.email = "student@university.edu"
+ account.current_tenant_id = "tenant-456"
+ token = "verification-token"
+ institution = "MIT"
+ role = "student"
+
+ # Import the error class to properly catch it
+ from controllers.console.error import EducationActivateLimitError
+
+ # Mock the rate limiter to return True (rate limited)
+ with patch.object(
+ BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=True
+ ) as mock_is_limited:
+ # Act & Assert
+ with pytest.raises(EducationActivateLimitError):
+ BillingService.EducationIdentity.activate(account, token, institution, role)
+
+ mock_is_limited.assert_called_once_with(account.email)
+ mock_send_request.assert_not_called()
+
+
+class TestBillingServiceEducationIdentity:
+ """Unit tests for education identity verification and management.
+
+ Tests cover:
+ - Education verification status checking
+ - Institution autocomplete with pagination
+ - Default parameter handling
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_education_status(self, mock_send_request):
+ """Test checking education verification status."""
+ # Arrange
+ account_id = "account-123"
+ expected_response = {"verified": True, "institution": "MIT", "role": "student"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.status(account_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("GET", "/education/status", params={"account_id": account_id})
+
+ def test_education_autocomplete(self, mock_send_request):
+ """Test education institution autocomplete."""
+ # Arrange
+ keywords = "Massachusetts"
+ page = 0
+ limit = 20
+ expected_response = {
+ "institutions": [
+ {"name": "Massachusetts Institute of Technology", "domain": "mit.edu"},
+ {"name": "University of Massachusetts", "domain": "umass.edu"},
+ ]
+ }
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.autocomplete(keywords, page, limit)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/education/autocomplete", params={"keywords": keywords, "page": page, "limit": limit}
+ )
+
+ def test_education_autocomplete_with_defaults(self, mock_send_request):
+ """Test education institution autocomplete with default parameters."""
+ # Arrange
+ keywords = "Stanford"
+ expected_response = {"institutions": [{"name": "Stanford University", "domain": "stanford.edu"}]}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.EducationIdentity.autocomplete(keywords)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET", "/education/autocomplete", params={"keywords": keywords, "page": 0, "limit": 20}
+ )
+
+
+class TestBillingServiceAccountManagement:
+ """Unit tests for account-related billing operations.
+
+ Tests cover:
+ - Account deletion
+ - Email freeze status checking
+ - Account deletion feedback submission
+ - Tenant owner/admin permission validation
+ - Error handling for missing tenant joins
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """Mock database session."""
+ with patch("services.billing_service.db.session") as mock_session:
+ yield mock_session
+
+ def test_delete_account(self, mock_send_request):
+ """Test account deletion."""
+ # Arrange
+ account_id = "account-123"
+ expected_response = {"result": "success", "deleted": True}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.delete_account(account_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with("DELETE", "/account/", params={"account_id": account_id})
+
+ def test_is_email_in_freeze_true(self, mock_send_request):
+ """Test checking if email is frozen (returns True)."""
+ # Arrange
+ email = "frozen@example.com"
+ mock_send_request.return_value = {"data": True}
+
+ # Act
+ result = BillingService.is_email_in_freeze(email)
+
+ # Assert
+ assert result is True
+ mock_send_request.assert_called_once_with("GET", "/account/in-freeze", params={"email": email})
+
+ def test_is_email_in_freeze_false(self, mock_send_request):
+ """Test checking if email is frozen (returns False)."""
+ # Arrange
+ email = "active@example.com"
+ mock_send_request.return_value = {"data": False}
+
+ # Act
+ result = BillingService.is_email_in_freeze(email)
+
+ # Assert
+ assert result is False
+ mock_send_request.assert_called_once_with("GET", "/account/in-freeze", params={"email": email})
+
+ def test_is_email_in_freeze_exception_returns_false(self, mock_send_request):
+ """Test that is_email_in_freeze returns False on exception."""
+ # Arrange
+ email = "error@example.com"
+ mock_send_request.side_effect = Exception("Network error")
+
+ # Act
+ result = BillingService.is_email_in_freeze(email)
+
+ # Assert
+ assert result is False
+
+ def test_update_account_deletion_feedback(self, mock_send_request):
+ """Test updating account deletion feedback."""
+ # Arrange
+ email = "user@example.com"
+ feedback = "Service was too expensive"
+ expected_response = {"result": "success"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_account_deletion_feedback(email, feedback)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "POST", "/account/delete-feedback", json={"email": email, "feedback": feedback}
+ )
+
+ def test_is_tenant_owner_or_admin_owner(self, mock_db_session):
+ """Test tenant owner/admin check for owner role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.OWNER
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_db_session.query.return_value = mock_query
+
+ # Act - should not raise exception
+ BillingService.is_tenant_owner_or_admin(current_user)
+
+ # Assert
+ mock_db_session.query.assert_called_once()
+
+ def test_is_tenant_owner_or_admin_admin(self, mock_db_session):
+ """Test tenant owner/admin check for admin role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.ADMIN
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_db_session.query.return_value = mock_query
+
+ # Act - should not raise exception
+ BillingService.is_tenant_owner_or_admin(current_user)
+
+ # Assert
+ mock_db_session.query.assert_called_once()
+
+ def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session):
+ """Test tenant owner/admin check raises error for normal user."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.NORMAL
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Only team owner or team admin can perform this action" in str(exc_info.value)
+
+ def test_is_tenant_owner_or_admin_no_join_raises_error(self, mock_db_session):
+ """Test tenant owner/admin check raises error when join not found."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Tenant account join not found" in str(exc_info.value)
+
+
+class TestBillingServiceCacheManagement:
+ """Unit tests for billing cache management.
+
+ Tests cover:
+ - Billing info cache invalidation
+ - Proper Redis key formatting
+ """
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client."""
+ with patch("services.billing_service.redis_client") as mock_redis:
+ yield mock_redis
+
+ def test_clean_billing_info_cache(self, mock_redis_client):
+ """Test cleaning billing info cache."""
+ # Arrange
+ tenant_id = "tenant-123"
+ expected_key = f"tenant:{tenant_id}:billing_info"
+
+ # Act
+ BillingService.clean_billing_info_cache(tenant_id)
+
+ # Assert
+ mock_redis_client.delete.assert_called_once_with(expected_key)
+
+
+class TestBillingServicePartnerIntegration:
+ """Unit tests for partner integration features.
+
+ Tests cover:
+ - Partner tenant binding synchronization
+ - Click ID tracking
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_sync_partner_tenants_bindings(self, mock_send_request):
+ """Test syncing partner tenant bindings."""
+ # Arrange
+ account_id = "account-123"
+ partner_key = "partner-xyz"
+ click_id = "click-789"
+ expected_response = {"result": "success", "synced": True}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.sync_partner_tenants_bindings(account_id, partner_key, click_id)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "PUT", f"/partners/{partner_key}/tenants", json={"account_id": account_id, "click_id": click_id}
+ )
+
+
+class TestBillingServiceEdgeCases:
+ """Unit tests for edge cases and error scenarios.
+
+ Tests cover:
+ - Empty responses from billing API
+ - Malformed JSON responses
+ - Boundary conditions for rate limits
+ - Multiple subscription tiers
+ - Zero and negative usage deltas
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_get_info_empty_response(self, mock_send_request):
+ """Test handling of empty billing info response."""
+ # Arrange
+ tenant_id = "tenant-empty"
+ mock_send_request.return_value = {}
+
+ # Act
+ result = BillingService.get_info(tenant_id)
+
+ # Assert
+ assert result == {}
+ mock_send_request.assert_called_once()
+
+ def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
+ """Test updating tenant feature usage with zero delta (no change)."""
+ # Arrange
+ tenant_id = "tenant-123"
+ feature_key = "trigger"
+ delta = 0 # No change
+ expected_response = {"result": "success", "history_id": "hist-uuid-zero"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "POST",
+ "/tenant-feature-usage/usage",
+ params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
+ )
+
+ def test_update_tenant_feature_plan_usage_large_negative_delta(self, mock_send_request):
+ """Test updating tenant feature usage with large negative delta."""
+ # Arrange
+ tenant_id = "tenant-456"
+ feature_key = "workflow"
+ delta = -1000 # Large consumption
+ expected_response = {"result": "success", "history_id": "hist-uuid-large"}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta)
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once()
+
+ def test_get_knowledge_rate_limit_all_subscription_tiers(self, mock_send_request):
+ """Test knowledge rate limit for all subscription tiers."""
+ # Test SANDBOX tier
+ mock_send_request.return_value = {"limit": 10, "subscription_plan": CloudPlan.SANDBOX}
+ result = BillingService.get_knowledge_rate_limit("tenant-sandbox")
+ assert result["subscription_plan"] == CloudPlan.SANDBOX
+ assert result["limit"] == 10
+
+ # Test PROFESSIONAL tier
+ mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL}
+ result = BillingService.get_knowledge_rate_limit("tenant-pro")
+ assert result["subscription_plan"] == CloudPlan.PROFESSIONAL
+ assert result["limit"] == 100
+
+ # Test TEAM tier
+ mock_send_request.return_value = {"limit": 500, "subscription_plan": CloudPlan.TEAM}
+ result = BillingService.get_knowledge_rate_limit("tenant-team")
+ assert result["subscription_plan"] == CloudPlan.TEAM
+ assert result["limit"] == 500
+
+ def test_get_subscription_with_empty_optional_params(self, mock_send_request):
+ """Test subscription payment link with empty optional parameters."""
+ # Arrange
+ plan = "professional"
+ interval = "yearly"
+ expected_response = {"payment_link": "https://payment.example.com/checkout"}
+ mock_send_request.return_value = expected_response
+
+ # Act - empty email and tenant_id
+ result = BillingService.get_subscription(plan, interval, "", "")
+
+ # Assert
+ assert result == expected_response
+ mock_send_request.assert_called_once_with(
+ "GET",
+ "/subscription/payment-link",
+ params={"plan": plan, "interval": interval, "prefilled_email": "", "tenant_id": ""},
+ )
+
+ def test_get_invoices_with_empty_params(self, mock_send_request):
+ """Test invoice retrieval with empty parameters."""
+ # Arrange
+ expected_response = {"invoices": []}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.get_invoices("", "")
+
+ # Assert
+ assert result == expected_response
+ assert result["invoices"] == []
+
+ def test_refund_with_invalid_history_id_format(self, mock_send_request):
+ """Test refund with various history ID formats."""
+ # Arrange - test with different ID formats
+ test_ids = ["hist-123", "uuid-abc-def", "12345", ""]
+
+ for history_id in test_ids:
+ expected_response = {"result": "success", "history_id": history_id}
+ mock_send_request.return_value = expected_response
+
+ # Act
+ result = BillingService.refund_tenant_feature_plan_usage(history_id)
+
+ # Assert
+ assert result["history_id"] == history_id
+
+ def test_is_tenant_owner_or_admin_editor_role_raises_error(self):
+ """Test tenant owner/admin check raises error for editor role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
+
+ with patch("services.billing_service.db.session") as mock_session:
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Only team owner or team admin can perform this action" in str(exc_info.value)
+
+ def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self):
+ """Test tenant owner/admin check raises error for dataset operator role."""
+ # Arrange
+ current_user = MagicMock(spec=Account)
+ current_user.id = "account-123"
+ current_user.current_tenant_id = "tenant-456"
+
+ mock_join = MagicMock(spec=TenantAccountJoin)
+ mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
+
+ with patch("services.billing_service.db.session") as mock_session:
+ mock_query = MagicMock()
+ mock_query.where.return_value.first.return_value = mock_join
+ mock_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ BillingService.is_tenant_owner_or_admin(current_user)
+ assert "Only team owner or team admin can perform this action" in str(exc_info.value)
+
+
+class TestBillingServiceIntegrationScenarios:
+ """Integration-style tests simulating real-world usage scenarios.
+
+ These tests combine multiple service methods to test common workflows:
+ - Complete subscription upgrade flow
+ - Usage tracking and refund workflow
+ - Rate limit boundary testing
+ """
+
+ @pytest.fixture
+ def mock_send_request(self):
+ """Mock _send_request method."""
+ with patch.object(BillingService, "_send_request") as mock:
+ yield mock
+
+ def test_subscription_upgrade_workflow(self, mock_send_request):
+ """Test complete subscription upgrade workflow."""
+ # Arrange
+ tenant_id = "tenant-upgrade"
+
+ # Step 1: Get current billing info
+ mock_send_request.return_value = {
+ "subscription_plan": "sandbox",
+ "billing_cycle": "monthly",
+ "status": "active",
+ }
+ current_info = BillingService.get_info(tenant_id)
+ assert current_info["subscription_plan"] == "sandbox"
+
+ # Step 2: Get payment link for upgrade
+ mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
+ payment_link = BillingService.get_subscription("professional", "monthly", "user@example.com", tenant_id)
+ assert "payment_link" in payment_link
+
+ # Step 3: Verify new rate limits after upgrade
+ mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL}
+ rate_limit = BillingService.get_knowledge_rate_limit(tenant_id)
+ assert rate_limit["subscription_plan"] == CloudPlan.PROFESSIONAL
+ assert rate_limit["limit"] == 100
+
+ def test_usage_tracking_and_refund_workflow(self, mock_send_request):
+ """Test usage tracking with subsequent refund."""
+ # Arrange
+ tenant_id = "tenant-usage"
+ feature_key = "workflow"
+
+ # Step 1: Consume credits
+ mock_send_request.return_value = {"result": "success", "history_id": "hist-consume-123"}
+ consume_result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, -10)
+ history_id = consume_result["history_id"]
+ assert history_id == "hist-consume-123"
+
+ # Step 2: Check current usage
+ mock_send_request.return_value = {"used": 10, "limit": 100, "remaining": 90}
+ usage = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key)
+ assert usage["used"] == 10
+ assert usage["remaining"] == 90
+
+ # Step 3: Refund the usage
+ mock_send_request.return_value = {"result": "success", "history_id": history_id}
+ refund_result = BillingService.refund_tenant_feature_plan_usage(history_id)
+ assert refund_result["result"] == "success"
+
+ # Step 4: Verify usage after refund
+ mock_send_request.return_value = {"used": 0, "limit": 100, "remaining": 100}
+ updated_usage = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key)
+ assert updated_usage["used"] == 0
+ assert updated_usage["remaining"] == 100
+
+ def test_compliance_download_multiple_requests_within_limit(self, mock_send_request):
+ """Test multiple compliance downloads within rate limit."""
+ # Arrange
+ account_id = "account-compliance"
+ tenant_id = "tenant-compliance"
+ doc_name = "compliance_report.pdf"
+ ip = "192.168.1.1"
+ device_info = "Mozilla/5.0"
+
+ # Mock rate limiter to allow 3 requests (under limit of 4)
+ with (
+ patch.object(
+ BillingService.compliance_download_rate_limiter, "is_rate_limited", side_effect=[False, False, False]
+ ) as mock_is_limited,
+ patch.object(BillingService.compliance_download_rate_limiter, "increment_rate_limit") as mock_increment,
+ ):
+ mock_send_request.return_value = {"download_link": "https://example.com/download"}
+
+ # Act - Make 3 requests
+ for i in range(3):
+ result = BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info)
+ assert "download_link" in result
+
+ # Assert - All 3 requests succeeded
+ assert mock_is_limited.call_count == 3
+ assert mock_increment.call_count == 3
+
+ def test_education_verification_and_activation_flow(self, mock_send_request):
+ """Test complete education verification and activation flow."""
+ # Arrange
+ account = MagicMock(spec=Account)
+ account.id = "account-edu"
+ account.email = "student@mit.edu"
+ account.current_tenant_id = "tenant-edu"
+
+ # Step 1: Search for institution
+ with (
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False
+ ),
+ patch.object(BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"),
+ ):
+ mock_send_request.return_value = {
+ "institutions": [{"name": "Massachusetts Institute of Technology", "domain": "mit.edu"}]
+ }
+ institutions = BillingService.EducationIdentity.autocomplete("MIT")
+ assert len(institutions["institutions"]) > 0
+
+ # Step 2: Verify email
+ with (
+ patch.object(
+ BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False
+ ),
+ patch.object(BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"),
+ ):
+ mock_send_request.return_value = {"verified": True, "institution": "MIT"}
+ verify_result = BillingService.EducationIdentity.verify(account.id, account.email)
+ assert verify_result["verified"] is True
+
+ # Step 3: Check status
+ mock_send_request.return_value = {"verified": True, "institution": "MIT", "role": "student"}
+ status = BillingService.EducationIdentity.status(account.id)
+ assert status["verified"] is True
+
+ # Step 4: Activate education benefits
+ with (
+ patch.object(BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=False),
+ patch.object(BillingService.EducationIdentity.activation_rate_limit, "increment_rate_limit"),
+ ):
+ mock_send_request.return_value = {"result": "success", "activated": True}
+ activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
+ assert activate_result["activated"] is True
From 2551f6f27967f663357c89f33f0f005a27913be1 Mon Sep 17 00:00:00 2001
From: jiangbo721
Date: Thu, 27 Nov 2025 10:51:48 +0800
Subject: [PATCH 019/431] =?UTF-8?q?feat:=20add=20APP=5FDEFAULT=5FACTIVE=5F?=
=?UTF-8?q?REQUESTS=20as=20the=20default=20value=20for=20APP=5FAC=E2=80=A6?=
=?UTF-8?q?=20(#26930)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
api/.env.example | 1 +
api/configs/feature/__init__.py | 4 ++++
api/services/app_generate_service.py | 2 +-
api/services/rag_pipeline/pipeline_generate_service.py | 9 +++++----
api/tests/integration_tests/.env.example | 1 +
.../services/test_app_generate_service.py | 1 +
docker/.env.example | 2 ++
docker/docker-compose.yaml | 1 +
8 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/api/.env.example b/api/.env.example
index fbf0b12f40..50607f5b35 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -540,6 +540,7 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200
+APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 7cce3847b4..9c0c48c955 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -73,6 +73,10 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
+ APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
+ description="Default number of concurrent active requests per app (0 for unlimited)",
+ default=0,
+ )
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index bb1ea742d0..dc85929b98 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -135,7 +135,7 @@ class AppGenerateService:
Returns:
The maximum number of active requests allowed
"""
- app_limit = app.max_active_requests or 0
+ app_limit = app.max_active_requests or dify_config.APP_DEFAULT_ACTIVE_REQUESTS
config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS
# Filter out infinite (0) values and return the minimum, or 0 if both are infinite
diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py
index e6cee64df6..f397b28283 100644
--- a/api/services/rag_pipeline/pipeline_generate_service.py
+++ b/api/services/rag_pipeline/pipeline_generate_service.py
@@ -53,10 +53,11 @@ class PipelineGenerateService:
@staticmethod
def _get_max_active_requests(app_model: App) -> int:
- max_active_requests = app_model.max_active_requests
- if max_active_requests is None:
- max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
- return max_active_requests
+ app_limit = app_model.max_active_requests or dify_config.APP_DEFAULT_ACTIVE_REQUESTS
+ config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS
+ # Filter out infinite (0) values and return the minimum, or 0 if both are infinite
+ limits = [limit for limit in [app_limit, config_limit] if limit > 0]
+ return min(limits) if limits else 0
@classmethod
def generate_single_iteration(
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 46d13079db..e508ceef66 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -175,6 +175,7 @@ MAX_VARIABLE_SIZE=204800
# App configuration
APP_MAX_EXECUTION_TIME=1200
+APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py
index 0f9ed94017..476f58585d 100644
--- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py
@@ -82,6 +82,7 @@ class TestAppGenerateService:
# Setup dify_config mock returns
mock_dify_config.BILLING_ENABLED = False
mock_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
+ mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_global_dify_config.BILLING_ENABLED = False
diff --git a/docker/.env.example b/docker/.env.example
index 0bfdc6b495..c9981baaba 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -133,6 +133,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
+# The default number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
+APP_DEFAULT_ACTIVE_REQUESTS=0
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0
APP_MAX_EXECUTION_TIME=1200
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 0302612045..17f33bbf72 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -34,6 +34,7 @@ x-shared-env: &shared-api-worker-env
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
+ APP_DEFAULT_ACTIVE_REQUESTS: ${APP_DEFAULT_ACTIVE_REQUESTS:-0}
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
From 2f6b3f1c5fc54121765d2201d8dd6bf0c89a5cc3 Mon Sep 17 00:00:00 2001
From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Date: Thu, 27 Nov 2025 10:54:00 +0800
Subject: [PATCH 020/431] hotfix: fix _extract_filename for rfc 5987 (#26230)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
---
api/factories/file_factory.py | 43 ++++++-
.../unit_tests/factories/test_file_factory.py | 119 +++++++++++++++++-
2 files changed, 156 insertions(+), 6 deletions(-)
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 2316e45179..737a79f2b0 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -1,5 +1,6 @@
import mimetypes
import os
+import re
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
@@ -268,15 +269,47 @@ def _build_from_remote_url(
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
- filename = None
+ filename: str | None = None
# Try to extract from Content-Disposition header first
if content_disposition:
- _, params = parse_options_header(content_disposition)
- # RFC 5987 https://datatracker.ietf.org/doc/html/rfc5987: filename* takes precedence over filename
- filename = params.get("filename*") or params.get("filename")
+ # Manually extract filename* parameter since parse_options_header doesn't support it
+ filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
+ if filename_star_match:
+ raw_star = filename_star_match.group(1).strip()
+ # Remove trailing quotes if present
+ raw_star = raw_star.removesuffix('"')
+ # format: charset'lang'value
+ try:
+ parts = raw_star.split("'", 2)
+ charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8"
+ value = parts[2] if len(parts) == 3 else parts[-1]
+ filename = urllib.parse.unquote(value, encoding=charset, errors="replace")
+ except Exception:
+ # Fallback: try to extract value after the last single quote
+ if "''" in raw_star:
+ filename = urllib.parse.unquote(raw_star.split("''")[-1])
+ else:
+ filename = urllib.parse.unquote(raw_star)
+
+ if not filename:
+ # Fallback to regular filename parameter
+ _, params = parse_options_header(content_disposition)
+ raw = params.get("filename")
+ if raw:
+ # Strip surrounding quotes and percent-decode if present
+ if len(raw) >= 2 and raw[0] == raw[-1] == '"':
+ raw = raw[1:-1]
+ filename = urllib.parse.unquote(raw)
# Fallback to URL path if no filename from header
if not filename:
- filename = os.path.basename(url_path)
+ candidate = os.path.basename(url_path)
+ filename = urllib.parse.unquote(candidate) if candidate else None
+ # Defense-in-depth: ensure basename only
+ if filename:
+ filename = os.path.basename(filename)
+ # Return None if filename is empty or only whitespace
+ if not filename or not filename.strip():
+ filename = None
return filename or None
diff --git a/api/tests/unit_tests/factories/test_file_factory.py b/api/tests/unit_tests/factories/test_file_factory.py
index 777fe5a6e7..e5f45044fa 100644
--- a/api/tests/unit_tests/factories/test_file_factory.py
+++ b/api/tests/unit_tests/factories/test_file_factory.py
@@ -2,7 +2,7 @@ import re
import pytest
-from factories.file_factory import _get_remote_file_info
+from factories.file_factory import _extract_filename, _get_remote_file_info
class _FakeResponse:
@@ -113,3 +113,120 @@ class TestGetRemoteFileInfo:
# Should generate a random hex filename with .bin extension
assert re.match(r"^[0-9a-f]{32}\.bin$", filename) is not None
assert mime_type == "application/octet-stream"
+
+
+class TestExtractFilename:
+ """Tests for _extract_filename function focusing on RFC5987 parsing and security."""
+
+ def test_no_content_disposition_uses_url_basename(self):
+ """Test that URL basename is used when no Content-Disposition header."""
+ result = _extract_filename("http://example.com/path/file.txt", None)
+ assert result == "file.txt"
+
+ def test_no_content_disposition_with_percent_encoded_url(self):
+ """Test that percent-encoded URL basename is decoded."""
+ result = _extract_filename("http://example.com/path/file%20name.txt", None)
+ assert result == "file name.txt"
+
+ def test_no_content_disposition_empty_url_path(self):
+ """Test that empty URL path returns None."""
+ result = _extract_filename("http://example.com/", None)
+ assert result is None
+
+ def test_simple_filename_header(self):
+ """Test basic filename extraction from Content-Disposition."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="test.txt"')
+ assert result == "test.txt"
+
+ def test_quoted_filename_with_spaces(self):
+ """Test filename with spaces in quotes."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="my file.txt"')
+ assert result == "my file.txt"
+
+ def test_unquoted_filename(self):
+ """Test unquoted filename."""
+ result = _extract_filename("http://example.com/", "attachment; filename=test.txt")
+ assert result == "test.txt"
+
+ def test_percent_encoded_filename(self):
+ """Test percent-encoded filename."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="file%20name.txt"')
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_utf8(self):
+ """Test RFC5987 filename* with UTF-8 encoding."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=UTF-8''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_chinese(self):
+ """Test RFC5987 filename* with Chinese characters."""
+ result = _extract_filename(
+ "http://example.com/", "attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95%E6%96%87%E4%BB%B6.txt"
+ )
+ assert result == "测试文件.txt"
+
+ def test_rfc5987_filename_star_with_language(self):
+ """Test RFC5987 filename* with language tag."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=UTF-8'en'file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_fallback_charset(self):
+ """Test RFC5987 filename* with fallback charset."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_rfc5987_filename_star_malformed_fallback(self):
+ """Test RFC5987 filename* with malformed format falls back to simple unquote."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=malformed%20filename.txt")
+ assert result == "malformed filename.txt"
+
+ def test_filename_star_takes_precedence_over_filename(self):
+ """Test that filename* takes precedence over filename."""
+ test_string = 'attachment; filename="old.txt"; filename*=UTF-8\'\'new.txt"'
+ result = _extract_filename("http://example.com/", test_string)
+ assert result == "new.txt"
+
+ def test_path_injection_protection(self):
+ """Test that path injection attempts are blocked by os.path.basename."""
+ result = _extract_filename("http://example.com/", 'attachment; filename="../../../etc/passwd"')
+ assert result == "passwd"
+
+ def test_path_injection_protection_rfc5987(self):
+ """Test that path injection attempts in RFC5987 are blocked."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=UTF-8''..%2F..%2F..%2Fetc%2Fpasswd")
+ assert result == "passwd"
+
+ def test_empty_filename_returns_none(self):
+ """Test that empty filename returns None."""
+ result = _extract_filename("http://example.com/", 'attachment; filename=""')
+ assert result is None
+
+ def test_whitespace_only_filename_returns_none(self):
+ """Test that whitespace-only filename returns None."""
+ result = _extract_filename("http://example.com/", 'attachment; filename=" "')
+ assert result is None
+
+ def test_complex_rfc5987_encoding(self):
+ """Test complex RFC5987 encoding with special characters."""
+ result = _extract_filename(
+ "http://example.com/",
+ "attachment; filename*=UTF-8''%E4%B8%AD%E6%96%87%E6%96%87%E4%BB%B6%20%28%E5%89%AF%E6%9C%AC%29.pdf",
+ )
+ assert result == "中文文件 (副本).pdf"
+
+ def test_iso8859_1_encoding(self):
+ """Test ISO-8859-1 encoding in RFC5987."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=ISO-8859-1''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_encoding_error_fallback(self):
+ """Test that encoding errors fall back to safe ASCII filename."""
+ result = _extract_filename("http://example.com/", "attachment; filename*=INVALID-CHARSET''file%20name.txt")
+ assert result == "file name.txt"
+
+ def test_mixed_quotes_and_encoding(self):
+ """Test filename with mixed quotes and percent encoding."""
+ result = _extract_filename(
+ "http://example.com/", 'attachment; filename="file%20with%20quotes%20%26%20encoding.txt"'
+ )
+ assert result == "file with quotes & encoding.txt"
From 09a8046b10809d583825f3fed400ea47c1705f65 Mon Sep 17 00:00:00 2001
From: Will
Date: Thu, 27 Nov 2025 10:56:21 +0800
Subject: [PATCH 021/431] fix: querying webhook trigger issue (#28753)
---
api/controllers/console/app/workflow_trigger.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index b3e5c9619f..5d16e4f979 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -43,7 +43,7 @@ console_ns.schema_model(
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
- @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
+ @console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required
From b786e101e52a4f763c4818f4f7637b191a611c09 Mon Sep 17 00:00:00 2001
From: Will
Date: Thu, 27 Nov 2025 10:58:35 +0800
Subject: [PATCH 022/431] fix: querying and setting the system default model
(#28743)
---
api/controllers/console/workspace/models.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 8e402b4bae..c820a8d1f2 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any
+from typing import Any, cast
from flask import request
from flask_restx import Resource
@@ -26,7 +26,7 @@ class ParserGetDefault(BaseModel):
class ParserPostDefault(BaseModel):
class Inner(BaseModel):
model_type: ModelType
- model: str
+ model: str | None = None
provider: str | None = None
model_settings: list[Inner]
@@ -150,7 +150,7 @@ console_ns.schema_model(
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
- @console_ns.expect(console_ns.models[ParserGetDefault.__name__], validate=True)
+ @console_ns.expect(console_ns.models[ParserGetDefault.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -186,7 +186,7 @@ class DefaultModelApi(Resource):
tenant_id=tenant_id,
model_type=model_setting.model_type,
provider=model_setting.provider,
- model=model_setting.model,
+ model=cast(str, model_setting.model),
)
except Exception as ex:
logger.exception(
From 7efa0df1fd119037386b5627652e02e621f0e1d1 Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 18:59:17 -0800
Subject: [PATCH 023/431] Add comprehensive API/controller tests for dataset
endpoints (list, create, update, delete, documents, segments, hit testing,
external datasets) (#28750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../unit_tests/services/controller_api.py | 1082 +++++++++++++++++
1 file changed, 1082 insertions(+)
create mode 100644 api/tests/unit_tests/services/controller_api.py
diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py
new file mode 100644
index 0000000000..762d7b9090
--- /dev/null
+++ b/api/tests/unit_tests/services/controller_api.py
@@ -0,0 +1,1082 @@
+"""
+Comprehensive API/Controller tests for Dataset endpoints.
+
+This module contains extensive integration tests for the dataset-related
+controller endpoints, testing the HTTP API layer that exposes dataset
+functionality through REST endpoints.
+
+The controller endpoints provide HTTP access to:
+- Dataset CRUD operations (list, create, update, delete)
+- Document management operations
+- Segment management operations
+- Hit testing (retrieval testing) operations
+- External dataset and knowledge API operations
+
+These tests verify that:
+- HTTP requests are properly routed to service methods
+- Request validation works correctly
+- Response formatting is correct
+- Authentication and authorization are enforced
+- Error handling returns appropriate HTTP status codes
+- Request/response serialization works properly
+
+================================================================================
+ARCHITECTURE OVERVIEW
+================================================================================
+
+The controller layer in Dify uses Flask-RESTX to provide RESTful API endpoints.
+Controllers act as a thin layer between HTTP requests and service methods,
+handling:
+
+1. Request Parsing: Extracting and validating parameters from HTTP requests
+2. Authentication: Verifying user identity and permissions
+3. Authorization: Checking if user has permission to perform operations
+4. Service Invocation: Calling appropriate service methods
+5. Response Formatting: Serializing service results to HTTP responses
+6. Error Handling: Converting exceptions to appropriate HTTP status codes
+
+Key Components:
+- Flask-RESTX Resources: Define endpoint classes with HTTP methods
+- Decorators: Handle authentication, authorization, and setup requirements
+- Request Parsers: Validate and extract request parameters
+- Response Models: Define response structure for Swagger documentation
+- Error Handlers: Convert exceptions to HTTP error responses
+
+================================================================================
+TESTING STRATEGY
+================================================================================
+
+This test suite follows a comprehensive testing strategy that covers:
+
+1. HTTP Request/Response Testing:
+ - GET, POST, PATCH, DELETE methods
+ - Query parameters and request body validation
+ - Response status codes and body structure
+ - Headers and content types
+
+2. Authentication and Authorization:
+ - Login required checks
+ - Account initialization checks
+ - Permission validation
+ - Role-based access control
+
+3. Request Validation:
+ - Required parameter validation
+ - Parameter type validation
+ - Parameter range validation
+ - Custom validation rules
+
+4. Error Handling:
+ - 400 Bad Request (validation errors)
+ - 401 Unauthorized (authentication errors)
+ - 403 Forbidden (authorization errors)
+ - 404 Not Found (resource not found)
+ - 500 Internal Server Error (unexpected errors)
+
+5. Service Integration:
+ - Service method invocation
+ - Service method parameter passing
+ - Service method return value handling
+ - Service exception handling
+
+================================================================================
+"""
+
+from unittest.mock import Mock, patch
+from uuid import uuid4
+
+import pytest
+from flask import Flask
+from flask_restx import Api
+
+from controllers.console.datasets.datasets import DatasetApi, DatasetListApi
+from controllers.console.datasets.external import (
+ ExternalApiTemplateListApi,
+)
+from controllers.console.datasets.hit_testing import HitTestingApi
+from models.dataset import Dataset, DatasetPermissionEnum
+
+# ============================================================================
+# Test Data Factory
+# ============================================================================
+# The Test Data Factory pattern is used here to centralize the creation of
+# test objects and mock instances. This approach provides several benefits:
+#
+# 1. Consistency: All test objects are created using the same factory methods,
+# ensuring consistent structure across all tests.
+#
+# 2. Maintainability: If the structure of models or services changes, we only
+# need to update the factory methods rather than every individual test.
+#
+# 3. Reusability: Factory methods can be reused across multiple test classes,
+# reducing code duplication.
+#
+# 4. Readability: Tests become more readable when they use descriptive factory
+# method calls instead of complex object construction logic.
+#
+# ============================================================================
+
+
+class ControllerApiTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for controller API tests.
+
+ This factory provides static methods to create mock objects for:
+ - Flask application and test client setup
+ - Dataset instances and related models
+ - User and authentication context
+ - HTTP request/response objects
+ - Service method return values
+
+ The factory methods help maintain consistency across tests and reduce
+ code duplication when setting up test scenarios.
+ """
+
+ @staticmethod
+ def create_flask_app():
+ """
+ Create a Flask test application for API testing.
+
+ Returns:
+ Flask application instance configured for testing
+ """
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ app.config["SECRET_KEY"] = "test-secret-key"
+ return app
+
+ @staticmethod
+ def create_api_instance(app):
+ """
+ Create a Flask-RESTX API instance.
+
+ Args:
+ app: Flask application instance
+
+ Returns:
+ Api instance configured for the application
+ """
+ api = Api(app, doc="/docs/")
+ return api
+
+ @staticmethod
+ def create_test_client(app, api, resource_class, route):
+ """
+ Create a Flask test client with a resource registered.
+
+ Args:
+ app: Flask application instance
+ api: Flask-RESTX API instance
+ resource_class: Resource class to register
+ route: URL route for the resource
+
+ Returns:
+ Flask test client instance
+ """
+ api.add_resource(resource_class, route)
+ return app.test_client()
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ name: str = "Test Dataset",
+ tenant_id: str = "tenant-123",
+ permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Dataset instance.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ name: Name of the dataset
+ tenant_id: Tenant identifier
+ permission: Dataset permission level
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ dataset.name = name
+ dataset.tenant_id = tenant_id
+ dataset.permission = permission
+ dataset.to_dict.return_value = {
+ "id": dataset_id,
+ "name": name,
+ "tenant_id": tenant_id,
+ "permission": permission.value,
+ }
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+ @staticmethod
+ def create_user_mock(
+ user_id: str = "user-123",
+ tenant_id: str = "tenant-123",
+ is_dataset_editor: bool = True,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock user/account instance.
+
+ Args:
+ user_id: Unique identifier for the user
+ tenant_id: Tenant identifier
+ is_dataset_editor: Whether user has dataset editor permissions
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a user/account instance
+ """
+ user = Mock()
+ user.id = user_id
+ user.current_tenant_id = tenant_id
+ user.is_dataset_editor = is_dataset_editor
+ user.has_edit_permission = True
+ user.is_dataset_operator = False
+ for key, value in kwargs.items():
+ setattr(user, key, value)
+ return user
+
+ @staticmethod
+ def create_paginated_response(items, total, page=1, per_page=20):
+ """
+ Create a mock paginated response.
+
+ Args:
+ items: List of items in the current page
+ total: Total number of items
+ page: Current page number
+ per_page: Items per page
+
+ Returns:
+ Mock paginated response object
+ """
+ response = Mock()
+ response.items = items
+ response.total = total
+ response.page = page
+ response.per_page = per_page
+ response.pages = (total + per_page - 1) // per_page
+ return response
+
+
+# ============================================================================
+# Tests for Dataset List Endpoint (GET /datasets)
+# ============================================================================
+
+
+class TestDatasetListApi:
+ """
+ Comprehensive API tests for DatasetListApi (GET /datasets endpoint).
+
+ This test class covers the dataset listing functionality through the
+ HTTP API, including pagination, search, filtering, and permissions.
+
+ The GET /datasets endpoint:
+ 1. Requires authentication and account initialization
+ 2. Supports pagination (page, limit parameters)
+ 3. Supports search by keyword
+ 4. Supports filtering by tag IDs
+ 5. Supports including all datasets (for admins)
+ 6. Returns paginated list of datasets
+
+ Test scenarios include:
+ - Successful dataset listing with pagination
+ - Search functionality
+ - Tag filtering
+ - Permission-based filtering
+ - Error handling (authentication, authorization)
+ """
+
+ @pytest.fixture
+ def app(self):
+ """
+ Create Flask test application.
+
+ Provides a Flask application instance configured for testing.
+ """
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """
+ Create Flask-RESTX API instance.
+
+ Provides an API instance for registering resources.
+ """
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client(self, app, api):
+ """
+ Create test client with DatasetListApi registered.
+
+ Provides a Flask test client that can make HTTP requests to
+ the dataset list endpoint.
+ """
+ return ControllerApiTestDataFactory.create_test_client(app, api, DatasetListApi, "/datasets")
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """
+ Mock current user and tenant context.
+
+ Provides mocked current_account_with_tenant function that returns
+ a user and tenant ID for testing authentication.
+ """
+ with patch("controllers.console.datasets.datasets.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock()
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_get_datasets_success(self, client, mock_current_user):
+ """
+ Test successful retrieval of dataset list.
+
+ Verifies that when authentication passes, the endpoint returns
+ a paginated list of datasets.
+
+ This test ensures:
+ - Authentication is checked
+ - Service method is called with correct parameters
+ - Response has correct structure
+ - Status code is 200
+ """
+ # Arrange
+ datasets = [
+ ControllerApiTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", name=f"Dataset {i}")
+ for i in range(3)
+ ]
+
+ paginated_response = ControllerApiTestDataFactory.create_paginated_response(
+ items=datasets, total=3, page=1, per_page=20
+ )
+
+ with patch("controllers.console.datasets.datasets.DatasetService.get_datasets") as mock_get_datasets:
+ mock_get_datasets.return_value = (datasets, 3)
+
+ # Act
+ response = client.get("/datasets?page=1&limit=20")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert "data" in data
+ assert len(data["data"]) == 3
+ assert data["total"] == 3
+ assert data["page"] == 1
+ assert data["limit"] == 20
+
+ # Verify service was called
+ mock_get_datasets.assert_called_once()
+
+ def test_get_datasets_with_search(self, client, mock_current_user):
+ """
+ Test dataset listing with search keyword.
+
+ Verifies that search functionality works correctly through the API.
+
+ This test ensures:
+ - Search keyword is passed to service method
+ - Filtered results are returned
+ - Response structure is correct
+ """
+ # Arrange
+ search_keyword = "test"
+ datasets = [ControllerApiTestDataFactory.create_dataset_mock(dataset_id="dataset-1", name="Test Dataset")]
+
+ with patch("controllers.console.datasets.datasets.DatasetService.get_datasets") as mock_get_datasets:
+ mock_get_datasets.return_value = (datasets, 1)
+
+ # Act
+ response = client.get(f"/datasets?keyword={search_keyword}")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert len(data["data"]) == 1
+
+ # Verify search keyword was passed
+ call_args = mock_get_datasets.call_args
+ assert call_args[1]["search"] == search_keyword
+
+ def test_get_datasets_with_pagination(self, client, mock_current_user):
+ """
+ Test dataset listing with pagination parameters.
+
+ Verifies that pagination works correctly through the API.
+
+ This test ensures:
+ - Page and limit parameters are passed correctly
+ - Pagination metadata is included in response
+ - Correct datasets are returned for the page
+ """
+ # Arrange
+ datasets = [
+ ControllerApiTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", name=f"Dataset {i}")
+ for i in range(5)
+ ]
+
+ with patch("controllers.console.datasets.datasets.DatasetService.get_datasets") as mock_get_datasets:
+ mock_get_datasets.return_value = (datasets[:3], 5) # First page with 3 items
+
+ # Act
+ response = client.get("/datasets?page=1&limit=3")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert len(data["data"]) == 3
+ assert data["page"] == 1
+ assert data["limit"] == 3
+
+ # Verify pagination parameters were passed
+ call_args = mock_get_datasets.call_args
+ assert call_args[0][0] == 1 # page
+ assert call_args[0][1] == 3 # per_page
+
+
+# ============================================================================
+# Tests for Dataset Detail Endpoint (GET /datasets/{id})
+# ============================================================================
+
+
+class TestDatasetApiGet:
+ """
+ Comprehensive API tests for DatasetApi GET method (GET /datasets/{id} endpoint).
+
+ This test class covers the single dataset retrieval functionality through
+ the HTTP API.
+
+ The GET /datasets/{id} endpoint:
+ 1. Requires authentication and account initialization
+ 2. Validates dataset exists
+ 3. Checks user permissions
+ 4. Returns dataset details
+
+ Test scenarios include:
+ - Successful dataset retrieval
+ - Dataset not found (404)
+ - Permission denied (403)
+ - Authentication required
+ """
+
+ @pytest.fixture
+ def app(self):
+ """Create Flask test application."""
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """Create Flask-RESTX API instance."""
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client(self, app, api):
+ """Create test client with DatasetApi registered."""
+ return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets/")
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """Mock current user and tenant context."""
+ with patch("controllers.console.datasets.datasets.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock()
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_get_dataset_success(self, client, mock_current_user):
+ """
+ Test successful retrieval of a single dataset.
+
+ Verifies that when authentication and permissions pass, the endpoint
+ returns dataset details.
+
+ This test ensures:
+ - Authentication is checked
+ - Dataset existence is validated
+ - Permissions are checked
+ - Dataset details are returned
+ - Status code is 200
+ """
+ # Arrange
+ dataset_id = str(uuid4())
+ dataset = ControllerApiTestDataFactory.create_dataset_mock(dataset_id=dataset_id, name="Test Dataset")
+
+ with (
+ patch("controllers.console.datasets.datasets.DatasetService.get_dataset") as mock_get_dataset,
+ patch("controllers.console.datasets.datasets.DatasetService.check_dataset_permission") as mock_check_perm,
+ ):
+ mock_get_dataset.return_value = dataset
+ mock_check_perm.return_value = None # No exception = permission granted
+
+ # Act
+ response = client.get(f"/datasets/{dataset_id}")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert data["id"] == dataset_id
+ assert data["name"] == "Test Dataset"
+
+ # Verify service methods were called
+ mock_get_dataset.assert_called_once_with(dataset_id)
+ mock_check_perm.assert_called_once()
+
+ def test_get_dataset_not_found(self, client, mock_current_user):
+ """
+ Test error handling when dataset is not found.
+
+ Verifies that when dataset doesn't exist, a 404 error is returned.
+
+ This test ensures:
+ - 404 status code is returned
+ - Error message is appropriate
+ - Service method is called
+ """
+ # Arrange
+ dataset_id = str(uuid4())
+
+ with (
+ patch("controllers.console.datasets.datasets.DatasetService.get_dataset") as mock_get_dataset,
+ patch("controllers.console.datasets.datasets.DatasetService.check_dataset_permission") as mock_check_perm,
+ ):
+ mock_get_dataset.return_value = None # Dataset not found
+
+ # Act
+ response = client.get(f"/datasets/{dataset_id}")
+
+ # Assert
+ assert response.status_code == 404
+
+ # Verify service was called
+ mock_get_dataset.assert_called_once()
+
+
+# ============================================================================
+# Tests for Dataset Create Endpoint (POST /datasets)
+# ============================================================================
+
+
+class TestDatasetApiCreate:
+ """
+ Comprehensive API tests for DatasetApi POST method (POST /datasets endpoint).
+
+ This test class covers the dataset creation functionality through the HTTP API.
+
+ The POST /datasets endpoint:
+ 1. Requires authentication and account initialization
+ 2. Validates request body
+ 3. Creates dataset via service
+ 4. Returns created dataset
+
+ Test scenarios include:
+ - Successful dataset creation
+ - Request validation errors
+ - Duplicate name errors
+ - Authentication required
+ """
+
+ @pytest.fixture
+ def app(self):
+ """Create Flask test application."""
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """Create Flask-RESTX API instance."""
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client(self, app, api):
+ """Create test client with DatasetApi registered."""
+ return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets")
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """Mock current user and tenant context."""
+ with patch("controllers.console.datasets.datasets.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock()
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_create_dataset_success(self, client, mock_current_user):
+ """
+ Test successful creation of a dataset.
+
+ Verifies that when all validation passes, a new dataset is created
+ and returned.
+
+ This test ensures:
+ - Request body is validated
+ - Service method is called with correct parameters
+ - Created dataset is returned
+ - Status code is 201
+ """
+ # Arrange
+ dataset_id = str(uuid4())
+ dataset = ControllerApiTestDataFactory.create_dataset_mock(dataset_id=dataset_id, name="New Dataset")
+
+ request_data = {
+ "name": "New Dataset",
+ "description": "Test description",
+ "permission": "only_me",
+ }
+
+ with patch("controllers.console.datasets.datasets.DatasetService.create_empty_dataset") as mock_create:
+ mock_create.return_value = dataset
+
+ # Act
+ response = client.post(
+ "/datasets",
+ json=request_data,
+ content_type="application/json",
+ )
+
+ # Assert
+ assert response.status_code == 201
+ data = response.get_json()
+ assert data["id"] == dataset_id
+ assert data["name"] == "New Dataset"
+
+ # Verify service was called
+ mock_create.assert_called_once()
+
+
+# ============================================================================
+# Tests for Hit Testing Endpoint (POST /datasets/{id}/hit-testing)
+# ============================================================================
+
+
+class TestHitTestingApi:
+ """
+ Comprehensive API tests for HitTestingApi (POST /datasets/{id}/hit-testing endpoint).
+
+ This test class covers the hit testing (retrieval testing) functionality
+ through the HTTP API.
+
+ The POST /datasets/{id}/hit-testing endpoint:
+ 1. Requires authentication and account initialization
+ 2. Validates dataset exists and user has permission
+ 3. Validates query parameters
+ 4. Performs retrieval testing
+ 5. Returns test results
+
+ Test scenarios include:
+ - Successful hit testing
+ - Query validation errors
+ - Dataset not found
+ - Permission denied
+ """
+
+ @pytest.fixture
+ def app(self):
+ """Create Flask test application."""
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """Create Flask-RESTX API instance."""
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client(self, app, api):
+ """Create test client with HitTestingApi registered."""
+ return ControllerApiTestDataFactory.create_test_client(
+ app, api, HitTestingApi, "/datasets//hit-testing"
+ )
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """Mock current user and tenant context."""
+ with patch("controllers.console.datasets.hit_testing.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock()
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_hit_testing_success(self, client, mock_current_user):
+ """
+ Test successful hit testing operation.
+
+ Verifies that when all validation passes, hit testing is performed
+ and results are returned.
+
+ This test ensures:
+ - Dataset validation passes
+ - Query validation passes
+ - Hit testing service is called
+ - Results are returned
+ - Status code is 200
+ """
+ # Arrange
+ dataset_id = str(uuid4())
+ dataset = ControllerApiTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
+
+ request_data = {
+ "query": "test query",
+ "top_k": 10,
+ }
+
+ expected_result = {
+ "query": {"content": "test query"},
+ "records": [
+ {"content": "Result 1", "score": 0.95},
+ {"content": "Result 2", "score": 0.85},
+ ],
+ }
+
+ with (
+ patch(
+ "controllers.console.datasets.hit_testing.HitTestingApi.get_and_validate_dataset"
+ ) as mock_get_dataset,
+ patch("controllers.console.datasets.hit_testing.HitTestingApi.parse_args") as mock_parse_args,
+ patch("controllers.console.datasets.hit_testing.HitTestingApi.hit_testing_args_check") as mock_check_args,
+ patch("controllers.console.datasets.hit_testing.HitTestingApi.perform_hit_testing") as mock_perform,
+ ):
+ mock_get_dataset.return_value = dataset
+ mock_parse_args.return_value = request_data
+ mock_check_args.return_value = None # No validation error
+ mock_perform.return_value = expected_result
+
+ # Act
+ response = client.post(
+ f"/datasets/{dataset_id}/hit-testing",
+ json=request_data,
+ content_type="application/json",
+ )
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert "query" in data
+ assert "records" in data
+ assert len(data["records"]) == 2
+
+ # Verify methods were called
+ mock_get_dataset.assert_called_once()
+ mock_parse_args.assert_called_once()
+ mock_check_args.assert_called_once()
+ mock_perform.assert_called_once()
+
+
+# ============================================================================
+# Tests for External Dataset Endpoints
+# ============================================================================
+
+
+class TestExternalDatasetApi:
+ """
+ Comprehensive API tests for External Dataset endpoints.
+
+ This test class covers the external knowledge API and external dataset
+ management functionality through the HTTP API.
+
+ Endpoints covered:
+ - GET /datasets/external-knowledge-api - List external knowledge APIs
+ - POST /datasets/external-knowledge-api - Create external knowledge API
+ - GET /datasets/external-knowledge-api/{id} - Get external knowledge API
+ - PATCH /datasets/external-knowledge-api/{id} - Update external knowledge API
+ - DELETE /datasets/external-knowledge-api/{id} - Delete external knowledge API
+ - POST /datasets/external - Create external dataset
+
+ Test scenarios include:
+ - Successful CRUD operations
+ - Request validation
+ - Authentication and authorization
+ - Error handling
+ """
+
+ @pytest.fixture
+ def app(self):
+ """Create Flask test application."""
+ return ControllerApiTestDataFactory.create_flask_app()
+
+ @pytest.fixture
+ def api(self, app):
+ """Create Flask-RESTX API instance."""
+ return ControllerApiTestDataFactory.create_api_instance(app)
+
+ @pytest.fixture
+ def client_list(self, app, api):
+ """Create test client for external knowledge API list endpoint."""
+ return ControllerApiTestDataFactory.create_test_client(
+ app, api, ExternalApiTemplateListApi, "/datasets/external-knowledge-api"
+ )
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """Mock current user and tenant context."""
+ with patch("controllers.console.datasets.external.current_account_with_tenant") as mock_get_user:
+ mock_user = ControllerApiTestDataFactory.create_user_mock(is_dataset_editor=True)
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_get_external_knowledge_apis_success(self, client_list, mock_current_user):
+ """
+ Test successful retrieval of external knowledge API list.
+
+ Verifies that the endpoint returns a paginated list of external
+ knowledge APIs.
+
+ This test ensures:
+ - Authentication is checked
+ - Service method is called
+ - Paginated response is returned
+ - Status code is 200
+ """
+ # Arrange
+ apis = [{"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"} for i in range(3)]
+
+ with patch(
+ "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis"
+ ) as mock_get_apis:
+ mock_get_apis.return_value = (apis, 3)
+
+ # Act
+ response = client_list.get("/datasets/external-knowledge-api?page=1&limit=20")
+
+ # Assert
+ assert response.status_code == 200
+ data = response.get_json()
+ assert "data" in data
+ assert len(data["data"]) == 3
+ assert data["total"] == 3
+
+ # Verify service was called
+ mock_get_apis.assert_called_once()
+
+
+# ============================================================================
+# Additional Documentation and Notes
+# ============================================================================
+#
+# This test suite covers the core API endpoints for dataset operations.
+# Additional test scenarios that could be added:
+#
+# 1. Document Endpoints:
+# - POST /datasets/{id}/documents - Upload/create documents
+# - GET /datasets/{id}/documents - List documents
+# - GET /datasets/{id}/documents/{doc_id} - Get document details
+# - PATCH /datasets/{id}/documents/{doc_id} - Update document
+# - DELETE /datasets/{id}/documents/{doc_id} - Delete document
+# - POST /datasets/{id}/documents/batch - Batch operations
+#
+# 2. Segment Endpoints:
+# - GET /datasets/{id}/segments - List segments
+# - GET /datasets/{id}/segments/{segment_id} - Get segment details
+# - PATCH /datasets/{id}/segments/{segment_id} - Update segment
+# - DELETE /datasets/{id}/segments/{segment_id} - Delete segment
+#
+# 3. Dataset Update/Delete Endpoints:
+# - PATCH /datasets/{id} - Update dataset
+# - DELETE /datasets/{id} - Delete dataset
+#
+# 4. Advanced Scenarios:
+# - File upload handling
+# - Large payload handling
+# - Concurrent request handling
+# - Rate limiting
+# - CORS headers
+#
+# These scenarios are not currently implemented but could be added if needed
+# based on real-world usage patterns or discovered edge cases.
+#
+# ============================================================================
+
+
+# ============================================================================
+# API Testing Best Practices
+# ============================================================================
+#
+# When writing API tests, consider the following best practices:
+#
+# 1. Test Structure:
+# - Use descriptive test names that explain what is being tested
+# - Follow Arrange-Act-Assert pattern
+# - Keep tests focused on a single scenario
+# - Use fixtures for common setup
+#
+# 2. Mocking Strategy:
+# - Mock external dependencies (database, services, etc.)
+# - Mock authentication and authorization
+# - Use realistic mock data
+# - Verify mock calls to ensure correct integration
+#
+# 3. Assertions:
+# - Verify HTTP status codes
+# - Verify response structure
+# - Verify response data values
+# - Verify service method calls
+# - Verify error messages when appropriate
+#
+# 4. Error Testing:
+# - Test all error paths (400, 401, 403, 404, 500)
+# - Test validation errors
+# - Test authentication failures
+# - Test authorization failures
+# - Test not found scenarios
+#
+# 5. Edge Cases:
+# - Test with empty data
+# - Test with missing required fields
+# - Test with invalid data types
+# - Test with boundary values
+# - Test with special characters
+#
+# ============================================================================
+
+
+# ============================================================================
+# Flask-RESTX Resource Testing Patterns
+# ============================================================================
+#
+# Flask-RESTX resources are tested using Flask's test client. The typical
+# pattern involves:
+#
+# 1. Creating a Flask test application
+# 2. Creating a Flask-RESTX API instance
+# 3. Registering the resource with a route
+# 4. Creating a test client
+# 5. Making HTTP requests through the test client
+# 6. Asserting on the response
+#
+# Example pattern:
+#
+# app = Flask(__name__)
+# app.config["TESTING"] = True
+# api = Api(app)
+# api.add_resource(MyResource, "/my-endpoint")
+# client = app.test_client()
+# response = client.get("/my-endpoint")
+# assert response.status_code == 200
+#
+# Decorators on resources (like @login_required) need to be mocked or
+# bypassed in tests. This is typically done by mocking the decorator
+# functions or the authentication functions they call.
+#
+# ============================================================================
+
+
+# ============================================================================
+# Request/Response Validation
+# ============================================================================
+#
+# API endpoints use Flask-RESTX request parsers to validate incoming requests.
+# These parsers:
+#
+# 1. Extract parameters from query strings, form data, or JSON body
+# 2. Validate parameter types (string, integer, float, boolean, etc.)
+# 3. Validate parameter ranges and constraints
+# 4. Provide default values when parameters are missing
+# 5. Raise BadRequest exceptions when validation fails
+#
+# Response formatting is handled by Flask-RESTX's marshal_with decorator
+# or marshal function, which:
+#
+# 1. Formats response data according to defined models
+# 2. Handles nested objects and lists
+# 3. Filters out fields not in the model
+# 4. Provides consistent response structure
+#
+# Tests should verify:
+# - Request validation works correctly
+# - Invalid requests return 400 Bad Request
+# - Response structure matches the defined model
+# - Response data values are correct
+#
+# ============================================================================
+
+
+# ============================================================================
+# Authentication and Authorization Testing
+# ============================================================================
+#
+# Most API endpoints require authentication and authorization. Testing these
+# aspects involves:
+#
+# 1. Authentication Testing:
+# - Test that unauthenticated requests are rejected (401)
+# - Test that authenticated requests are accepted
+# - Mock the authentication decorators/functions
+# - Verify user context is passed correctly
+#
+# 2. Authorization Testing:
+# - Test that unauthorized requests are rejected (403)
+# - Test that authorized requests are accepted
+# - Test different user roles and permissions
+# - Verify permission checks are performed
+#
+# 3. Common Patterns:
+# - Mock current_account_with_tenant() to return test user
+# - Mock permission check functions
+# - Test with different user roles (admin, editor, operator, etc.)
+# - Test with different permission levels (only_me, all_team, etc.)
+#
+# ============================================================================
+
+
+# ============================================================================
+# Error Handling in API Tests
+# ============================================================================
+#
+# API endpoints should handle errors gracefully and return appropriate HTTP
+# status codes. Testing error handling involves:
+#
+# 1. Service Exception Mapping:
+# - ValueError -> 400 Bad Request
+# - NotFound -> 404 Not Found
+# - Forbidden -> 403 Forbidden
+# - Unauthorized -> 401 Unauthorized
+# - Internal errors -> 500 Internal Server Error
+#
+# 2. Validation Error Testing:
+# - Test missing required parameters
+# - Test invalid parameter types
+# - Test parameter range violations
+# - Test custom validation rules
+#
+# 3. Error Response Structure:
+# - Verify error status code
+# - Verify error message is included
+# - Verify error structure is consistent
+# - Verify error details are helpful
+#
+# ============================================================================
+
+
+# ============================================================================
+# Performance and Scalability Considerations
+# ============================================================================
+#
+# While unit tests focus on correctness, API tests should also consider:
+#
+# 1. Response Time:
+# - Tests should complete quickly
+# - Avoid actual database or network calls
+# - Use mocks for slow operations
+#
+# 2. Resource Usage:
+# - Tests should not consume excessive memory
+# - Tests should clean up after themselves
+# - Use fixtures for resource management
+#
+# 3. Test Isolation:
+# - Tests should not depend on each other
+# - Tests should not share state
+# - Each test should be independently runnable
+#
+# 4. Maintainability:
+# - Tests should be easy to understand
+# - Tests should be easy to modify
+# - Use descriptive names and comments
+# - Follow consistent patterns
+#
+# ============================================================================
From 4ca4493084795eb065e03421e0ca8a67e832213a Mon Sep 17 00:00:00 2001
From: aka James4u
Date: Wed, 26 Nov 2025 19:00:10 -0800
Subject: [PATCH 024/431] Add comprehensive unit tests for MetadataService
(dataset metadata CRUD operations and filtering) (#28748)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../unit_tests/services/dataset_metadata.py | 1068 +++++++++++++++++
1 file changed, 1068 insertions(+)
create mode 100644 api/tests/unit_tests/services/dataset_metadata.py
diff --git a/api/tests/unit_tests/services/dataset_metadata.py b/api/tests/unit_tests/services/dataset_metadata.py
new file mode 100644
index 0000000000..5ba18d8dc0
--- /dev/null
+++ b/api/tests/unit_tests/services/dataset_metadata.py
@@ -0,0 +1,1068 @@
+"""
+Comprehensive unit tests for MetadataService.
+
+This module contains extensive unit tests for the MetadataService class,
+which handles dataset metadata CRUD operations and filtering/querying functionality.
+
+The MetadataService provides methods for:
+- Creating, reading, updating, and deleting metadata fields
+- Managing built-in metadata fields
+- Updating document metadata values
+- Metadata filtering and querying operations
+- Lock management for concurrent metadata operations
+
+Metadata in Dify allows users to add custom fields to datasets and documents,
+enabling rich filtering and search capabilities. Metadata can be of various
+types (string, number, date, boolean, etc.) and can be used to categorize
+and filter documents within a dataset.
+
+This test suite ensures:
+- Correct creation of metadata fields with validation
+- Proper updating of metadata names and values
+- Accurate deletion of metadata fields
+- Built-in field management (enable/disable)
+- Document metadata updates (partial and full)
+- Lock management for concurrent operations
+- Metadata querying and filtering functionality
+
+================================================================================
+ARCHITECTURE OVERVIEW
+================================================================================
+
+The MetadataService is a critical component in the Dify platform's metadata
+management system. It serves as the primary interface for all metadata-related
+operations, including field definitions and document-level metadata values.
+
+Key Concepts:
+1. DatasetMetadata: Defines a metadata field for a dataset. Each metadata
+ field has a name, type, and is associated with a specific dataset.
+
+2. DatasetMetadataBinding: Links metadata fields to documents. This allows
+ tracking which documents have which metadata fields assigned.
+
+3. Document Metadata: The actual metadata values stored on documents. This
+ is stored as a JSON object in the document's doc_metadata field.
+
+4. Built-in Fields: System-defined metadata fields that are automatically
+ available when enabled (document_name, uploader, upload_date, etc.).
+
+5. Lock Management: Redis-based locking to prevent concurrent metadata
+ operations that could cause data corruption.
+
+================================================================================
+TESTING STRATEGY
+================================================================================
+
+This test suite follows a comprehensive testing strategy that covers:
+
+1. CRUD Operations:
+ - Creating metadata fields with validation
+ - Reading/retrieving metadata fields
+ - Updating metadata field names
+ - Deleting metadata fields
+
+2. Built-in Field Management:
+ - Enabling built-in fields
+ - Disabling built-in fields
+ - Getting built-in field definitions
+
+3. Document Metadata Operations:
+ - Updating document metadata (partial and full)
+ - Managing metadata bindings
+ - Handling built-in field updates
+
+4. Lock Management:
+ - Acquiring locks for dataset operations
+ - Acquiring locks for document operations
+ - Handling lock conflicts
+
+5. Error Handling:
+ - Validation errors (name length, duplicates)
+ - Not found errors
+ - Lock conflict errors
+
+================================================================================
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.rag.index_processor.constant.built_in_field import BuiltInField
+from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
+from services.entities.knowledge_entities.knowledge_entities import (
+ MetadataArgs,
+ MetadataValue,
+)
+from services.metadata_service import MetadataService
+
+# ============================================================================
+# Test Data Factory
+# ============================================================================
+# The Test Data Factory pattern is used here to centralize the creation of
+# test objects and mock instances. This approach provides several benefits:
+#
+# 1. Consistency: All test objects are created using the same factory methods,
+# ensuring consistent structure across all tests.
+#
+# 2. Maintainability: If the structure of models changes, we only need to
+# update the factory methods rather than every individual test.
+#
+# 3. Reusability: Factory methods can be reused across multiple test classes,
+# reducing code duplication.
+#
+# 4. Readability: Tests become more readable when they use descriptive factory
+# method calls instead of complex object construction logic.
+#
+# ============================================================================
+
+
+class MetadataTestDataFactory:
+ """
+ Factory class for creating test data and mock objects for metadata service tests.
+
+ This factory provides static methods to create mock objects for:
+ - DatasetMetadata instances
+ - DatasetMetadataBinding instances
+ - Dataset instances
+ - Document instances
+ - MetadataArgs and MetadataOperationData entities
+ - User and tenant context
+
+ The factory methods help maintain consistency across tests and reduce
+ code duplication when setting up test scenarios.
+ """
+
+ @staticmethod
+ def create_metadata_mock(
+ metadata_id: str = "metadata-123",
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ name: str = "category",
+ metadata_type: str = "string",
+ created_by: str = "user-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock DatasetMetadata with specified attributes.
+
+ Args:
+ metadata_id: Unique identifier for the metadata field
+ dataset_id: ID of the dataset this metadata belongs to
+ tenant_id: Tenant identifier
+ name: Name of the metadata field
+ metadata_type: Type of metadata (string, number, date, etc.)
+ created_by: ID of the user who created the metadata
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a DatasetMetadata instance
+ """
+ metadata = Mock(spec=DatasetMetadata)
+ metadata.id = metadata_id
+ metadata.dataset_id = dataset_id
+ metadata.tenant_id = tenant_id
+ metadata.name = name
+ metadata.type = metadata_type
+ metadata.created_by = created_by
+ metadata.updated_by = None
+ metadata.updated_at = None
+ for key, value in kwargs.items():
+ setattr(metadata, key, value)
+ return metadata
+
+ @staticmethod
+ def create_metadata_binding_mock(
+ binding_id: str = "binding-123",
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ metadata_id: str = "metadata-123",
+ document_id: str = "document-123",
+ created_by: str = "user-123",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock DatasetMetadataBinding with specified attributes.
+
+ Args:
+ binding_id: Unique identifier for the binding
+ dataset_id: ID of the dataset
+ tenant_id: Tenant identifier
+ metadata_id: ID of the metadata field
+ document_id: ID of the document
+ created_by: ID of the user who created the binding
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a DatasetMetadataBinding instance
+ """
+ binding = Mock(spec=DatasetMetadataBinding)
+ binding.id = binding_id
+ binding.dataset_id = dataset_id
+ binding.tenant_id = tenant_id
+ binding.metadata_id = metadata_id
+ binding.document_id = document_id
+ binding.created_by = created_by
+ for key, value in kwargs.items():
+ setattr(binding, key, value)
+ return binding
+
+ @staticmethod
+ def create_dataset_mock(
+ dataset_id: str = "dataset-123",
+ tenant_id: str = "tenant-123",
+ built_in_field_enabled: bool = False,
+ doc_metadata: list | None = None,
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Dataset with specified attributes.
+
+ Args:
+ dataset_id: Unique identifier for the dataset
+ tenant_id: Tenant identifier
+ built_in_field_enabled: Whether built-in fields are enabled
+ doc_metadata: List of metadata field definitions
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Dataset instance
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = dataset_id
+ dataset.tenant_id = tenant_id
+ dataset.built_in_field_enabled = built_in_field_enabled
+ dataset.doc_metadata = doc_metadata or []
+ for key, value in kwargs.items():
+ setattr(dataset, key, value)
+ return dataset
+
+ @staticmethod
+ def create_document_mock(
+ document_id: str = "document-123",
+ dataset_id: str = "dataset-123",
+ name: str = "Test Document",
+ doc_metadata: dict | None = None,
+ uploader: str = "user-123",
+ data_source_type: str = "upload_file",
+ **kwargs,
+ ) -> Mock:
+ """
+ Create a mock Document with specified attributes.
+
+ Args:
+ document_id: Unique identifier for the document
+ dataset_id: ID of the dataset this document belongs to
+ name: Name of the document
+ doc_metadata: Dictionary of metadata values
+ uploader: ID of the user who uploaded the document
+ data_source_type: Type of data source
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ Mock object configured as a Document instance
+ """
+ document = Mock()
+ document.id = document_id
+ document.dataset_id = dataset_id
+ document.name = name
+ document.doc_metadata = doc_metadata or {}
+ document.uploader = uploader
+ document.data_source_type = data_source_type
+
+ # Mock datetime objects for upload_date and last_update_date
+
+ document.upload_date = Mock()
+ document.upload_date.timestamp.return_value = 1234567890.0
+ document.last_update_date = Mock()
+ document.last_update_date.timestamp.return_value = 1234567890.0
+
+ for key, value in kwargs.items():
+ setattr(document, key, value)
+ return document
+
+ @staticmethod
+ def create_metadata_args_mock(
+ name: str = "category",
+ metadata_type: str = "string",
+ ) -> Mock:
+ """
+ Create a mock MetadataArgs entity.
+
+ Args:
+ name: Name of the metadata field
+ metadata_type: Type of metadata
+
+ Returns:
+ Mock object configured as a MetadataArgs instance
+ """
+ metadata_args = Mock(spec=MetadataArgs)
+ metadata_args.name = name
+ metadata_args.type = metadata_type
+ return metadata_args
+
+ @staticmethod
+ def create_metadata_value_mock(
+ metadata_id: str = "metadata-123",
+ name: str = "category",
+ value: str = "test",
+ ) -> Mock:
+ """
+ Create a mock MetadataValue entity.
+
+ Args:
+ metadata_id: ID of the metadata field
+ name: Name of the metadata field
+ value: Value of the metadata
+
+ Returns:
+ Mock object configured as a MetadataValue instance
+ """
+ metadata_value = Mock(spec=MetadataValue)
+ metadata_value.id = metadata_id
+ metadata_value.name = name
+ metadata_value.value = value
+ return metadata_value
+
+
+# ============================================================================
+# Tests for create_metadata
+# ============================================================================
+
+
+class TestMetadataServiceCreateMetadata:
+ """
+ Comprehensive unit tests for MetadataService.create_metadata method.
+
+ This test class covers the metadata field creation functionality,
+ including validation, duplicate checking, and database operations.
+
+ The create_metadata method:
+ 1. Validates metadata name length (max 255 characters)
+ 2. Checks for duplicate metadata names within the dataset
+ 3. Checks for conflicts with built-in field names
+ 4. Creates a new DatasetMetadata instance
+ 5. Adds it to the database session and commits
+ 6. Returns the created metadata
+
+ Test scenarios include:
+ - Successful creation with valid data
+ - Name length validation
+ - Duplicate name detection
+ - Built-in field name conflicts
+ - Database transaction handling
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides a mocked database session that can be used to verify:
+ - Query construction and execution
+ - Add operations for new metadata
+ - Commit operations for transaction completion
+ """
+ with patch("services.metadata_service.db.session") as mock_db:
+ yield mock_db
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """
+ Mock current user and tenant context.
+
+ Provides mocked current_account_with_tenant function that returns
+ a user and tenant ID for testing authentication and authorization.
+ """
+ with patch("services.metadata_service.current_account_with_tenant") as mock_get_user:
+ mock_user = Mock()
+ mock_user.id = "user-123"
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ def test_create_metadata_success(self, mock_db_session, mock_current_user):
+ """
+ Test successful creation of a metadata field.
+
+ Verifies that when all validation passes, a new metadata field
+ is created and persisted to the database.
+
+ This test ensures:
+ - Metadata name validation passes
+ - No duplicate name exists
+ - No built-in field conflict
+ - New metadata is added to database
+ - Transaction is committed
+ - Created metadata is returned
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
+
+ # Mock query to return None (no existing metadata with same name)
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Mock BuiltInField enum iteration
+ with patch("services.metadata_service.BuiltInField") as mock_builtin:
+ mock_builtin.__iter__ = Mock(return_value=iter([]))
+
+ # Act
+ result = MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Assert
+ assert result is not None
+ assert isinstance(result, DatasetMetadata)
+
+ # Verify query was made to check for duplicates
+ mock_db_session.query.assert_called()
+ mock_query.filter_by.assert_called()
+
+ # Verify metadata was added and committed
+ mock_db_session.add.assert_called_once()
+ mock_db_session.commit.assert_called_once()
+
+ def test_create_metadata_name_too_long_error(self, mock_db_session, mock_current_user):
+ """
+ Test error handling when metadata name exceeds 255 characters.
+
+ Verifies that when a metadata name is longer than 255 characters,
+ a ValueError is raised with an appropriate message.
+
+ This test ensures:
+ - Name length validation is enforced
+ - Error message is clear and descriptive
+ - No database operations are performed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ long_name = "a" * 256 # 256 characters (exceeds limit)
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name=long_name, metadata_type="string")
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters"):
+ MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Verify no database operations were performed
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_create_metadata_duplicate_name_error(self, mock_db_session, mock_current_user):
+ """
+ Test error handling when metadata name already exists.
+
+ Verifies that when a metadata field with the same name already exists
+ in the dataset, a ValueError is raised.
+
+ This test ensures:
+ - Duplicate name detection works correctly
+ - Error message is clear
+ - No new metadata is created
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
+
+ # Mock existing metadata with same name
+ existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category")
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = existing_metadata
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata name already exists"):
+ MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Verify no new metadata was added
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+ def test_create_metadata_builtin_field_conflict_error(self, mock_db_session, mock_current_user):
+ """
+ Test error handling when metadata name conflicts with built-in field.
+
+ Verifies that when a metadata name matches a built-in field name,
+ a ValueError is raised.
+
+ This test ensures:
+ - Built-in field name conflicts are detected
+ - Error message is clear
+ - No new metadata is created
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_args = MetadataTestDataFactory.create_metadata_args_mock(
+ name=BuiltInField.document_name, metadata_type="string"
+ )
+
+ # Mock query to return None (no duplicate in database)
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Mock BuiltInField to include the conflicting name
+ with patch("services.metadata_service.BuiltInField") as mock_builtin:
+ mock_field = Mock()
+ mock_field.value = BuiltInField.document_name
+ mock_builtin.__iter__ = Mock(return_value=iter([mock_field]))
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields"):
+ MetadataService.create_metadata(dataset_id, metadata_args)
+
+ # Verify no new metadata was added
+ mock_db_session.add.assert_not_called()
+ mock_db_session.commit.assert_not_called()
+
+
+# ============================================================================
+# Tests for update_metadata_name
+# ============================================================================
+
+
+class TestMetadataServiceUpdateMetadataName:
+ """
+ Comprehensive unit tests for MetadataService.update_metadata_name method.
+
+ This test class covers the metadata field name update functionality,
+ including validation, duplicate checking, and document metadata updates.
+
+ The update_metadata_name method:
+ 1. Validates new name length (max 255 characters)
+ 2. Checks for duplicate names
+ 3. Checks for built-in field conflicts
+ 4. Acquires a lock for the dataset
+ 5. Updates the metadata name
+ 6. Updates all related document metadata
+ 7. Releases the lock
+ 8. Returns the updated metadata
+
+ Test scenarios include:
+ - Successful name update
+ - Name length validation
+ - Duplicate name detection
+ - Built-in field conflicts
+ - Lock management
+ - Document metadata updates
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """Mock database session for testing."""
+ with patch("services.metadata_service.db.session") as mock_db:
+ yield mock_db
+
+ @pytest.fixture
+ def mock_current_user(self):
+ """Mock current user and tenant context."""
+ with patch("services.metadata_service.current_account_with_tenant") as mock_get_user:
+ mock_user = Mock()
+ mock_user.id = "user-123"
+ mock_tenant_id = "tenant-123"
+ mock_get_user.return_value = (mock_user, mock_tenant_id)
+ yield mock_get_user
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client for lock management."""
+ with patch("services.metadata_service.redis_client") as mock_redis:
+ mock_redis.get.return_value = None # No existing lock
+ mock_redis.set.return_value = True
+ mock_redis.delete.return_value = True
+ yield mock_redis
+
+ def test_update_metadata_name_success(self, mock_db_session, mock_current_user, mock_redis_client):
+ """
+ Test successful update of metadata field name.
+
+ Verifies that when all validation passes, the metadata name is
+ updated and all related document metadata is updated accordingly.
+
+ This test ensures:
+ - Name validation passes
+ - Lock is acquired and released
+ - Metadata name is updated
+ - Related document metadata is updated
+ - Transaction is committed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_id = "metadata-123"
+ new_name = "updated_category"
+
+ existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
+
+ # Mock query for duplicate check (no duplicate)
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Mock metadata retrieval
+ def query_side_effect(model):
+ if model == DatasetMetadata:
+ mock_meta_query = Mock()
+ mock_meta_query.filter_by.return_value = mock_meta_query
+ mock_meta_query.first.return_value = existing_metadata
+ return mock_meta_query
+ return mock_query
+
+ mock_db_session.query.side_effect = query_side_effect
+
+ # Mock no metadata bindings (no documents to update)
+ mock_binding_query = Mock()
+ mock_binding_query.filter_by.return_value = mock_binding_query
+ mock_binding_query.all.return_value = []
+
+ # Mock BuiltInField enum
+ with patch("services.metadata_service.BuiltInField") as mock_builtin:
+ mock_builtin.__iter__ = Mock(return_value=iter([]))
+
+ # Act
+ result = MetadataService.update_metadata_name(dataset_id, metadata_id, new_name)
+
+ # Assert
+ assert result is not None
+ assert result.name == new_name
+
+ # Verify lock was acquired and released
+ mock_redis_client.get.assert_called()
+ mock_redis_client.set.assert_called()
+ mock_redis_client.delete.assert_called()
+
+ # Verify metadata was updated and committed
+ mock_db_session.commit.assert_called()
+
+ def test_update_metadata_name_not_found_error(self, mock_db_session, mock_current_user, mock_redis_client):
+ """
+ Test error handling when metadata is not found.
+
+ Verifies that when the metadata ID doesn't exist, a ValueError
+ is raised with an appropriate message.
+
+ This test ensures:
+ - Not found error is handled correctly
+ - Lock is properly released even on error
+ - No updates are committed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_id = "non-existent-metadata"
+ new_name = "updated_category"
+
+ # Mock query for duplicate check (no duplicate)
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Mock metadata retrieval to return None
+ def query_side_effect(model):
+ if model == DatasetMetadata:
+ mock_meta_query = Mock()
+ mock_meta_query.filter_by.return_value = mock_meta_query
+ mock_meta_query.first.return_value = None # Not found
+ return mock_meta_query
+ return mock_query
+
+ mock_db_session.query.side_effect = query_side_effect
+
+ # Mock BuiltInField enum
+ with patch("services.metadata_service.BuiltInField") as mock_builtin:
+ mock_builtin.__iter__ = Mock(return_value=iter([]))
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata not found"):
+ MetadataService.update_metadata_name(dataset_id, metadata_id, new_name)
+
+ # Verify lock was released
+ mock_redis_client.delete.assert_called()
+
+
+# ============================================================================
+# Tests for delete_metadata
+# ============================================================================
+
+
+class TestMetadataServiceDeleteMetadata:
+ """
+ Comprehensive unit tests for MetadataService.delete_metadata method.
+
+ This test class covers the metadata field deletion functionality,
+ including document metadata cleanup and lock management.
+
+ The delete_metadata method:
+ 1. Acquires a lock for the dataset
+ 2. Retrieves the metadata to delete
+ 3. Deletes the metadata from the database
+ 4. Removes metadata from all related documents
+ 5. Releases the lock
+ 6. Returns the deleted metadata
+
+ Test scenarios include:
+ - Successful deletion
+ - Not found error handling
+ - Document metadata cleanup
+ - Lock management
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """Mock database session for testing."""
+ with patch("services.metadata_service.db.session") as mock_db:
+ yield mock_db
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client for lock management."""
+ with patch("services.metadata_service.redis_client") as mock_redis:
+ mock_redis.get.return_value = None
+ mock_redis.set.return_value = True
+ mock_redis.delete.return_value = True
+ yield mock_redis
+
+ def test_delete_metadata_success(self, mock_db_session, mock_redis_client):
+ """
+ Test successful deletion of a metadata field.
+
+ Verifies that when the metadata exists, it is deleted and all
+ related document metadata is cleaned up.
+
+ This test ensures:
+ - Lock is acquired and released
+ - Metadata is deleted from database
+ - Related document metadata is removed
+ - Transaction is committed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_id = "metadata-123"
+
+ existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
+
+ # Mock metadata retrieval
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = existing_metadata
+ mock_db_session.query.return_value = mock_query
+
+ # Mock no metadata bindings (no documents to update)
+ mock_binding_query = Mock()
+ mock_binding_query.filter_by.return_value = mock_binding_query
+ mock_binding_query.all.return_value = []
+
+ # Act
+ result = MetadataService.delete_metadata(dataset_id, metadata_id)
+
+ # Assert
+ assert result == existing_metadata
+
+ # Verify lock was acquired and released
+ mock_redis_client.get.assert_called()
+ mock_redis_client.set.assert_called()
+ mock_redis_client.delete.assert_called()
+
+ # Verify metadata was deleted and committed
+ mock_db_session.delete.assert_called_once_with(existing_metadata)
+ mock_db_session.commit.assert_called()
+
+ def test_delete_metadata_not_found_error(self, mock_db_session, mock_redis_client):
+ """
+ Test error handling when metadata is not found.
+
+ Verifies that when the metadata ID doesn't exist, a ValueError
+ is raised and the lock is properly released.
+
+ This test ensures:
+ - Not found error is handled correctly
+ - Lock is released even on error
+ - No deletion is performed
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ metadata_id = "non-existent-metadata"
+
+ # Mock metadata retrieval to return None
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.first.return_value = None
+ mock_db_session.query.return_value = mock_query
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Metadata not found"):
+ MetadataService.delete_metadata(dataset_id, metadata_id)
+
+ # Verify lock was released
+ mock_redis_client.delete.assert_called()
+
+ # Verify no deletion was performed
+ mock_db_session.delete.assert_not_called()
+
+
+# ============================================================================
+# Tests for get_built_in_fields
+# ============================================================================
+
+
+class TestMetadataServiceGetBuiltInFields:
+ """
+ Comprehensive unit tests for MetadataService.get_built_in_fields method.
+
+ This test class covers the built-in field retrieval functionality.
+
+ The get_built_in_fields method:
+ 1. Returns a list of built-in field definitions
+ 2. Each definition includes name and type
+
+ Test scenarios include:
+ - Successful retrieval of built-in fields
+ - Correct field definitions
+ """
+
+ def test_get_built_in_fields_success(self):
+ """
+ Test successful retrieval of built-in fields.
+
+ Verifies that the method returns the correct list of built-in
+ field definitions with proper structure.
+
+ This test ensures:
+ - All built-in fields are returned
+ - Each field has name and type
+ - Field definitions are correct
+ """
+ # Act
+ result = MetadataService.get_built_in_fields()
+
+ # Assert
+ assert isinstance(result, list)
+ assert len(result) > 0
+
+ # Verify each field has required properties
+ for field in result:
+ assert "name" in field
+ assert "type" in field
+ assert isinstance(field["name"], str)
+ assert isinstance(field["type"], str)
+
+ # Verify specific built-in fields are present
+ field_names = [field["name"] for field in result]
+ assert BuiltInField.document_name in field_names
+ assert BuiltInField.uploader in field_names
+
+
+# ============================================================================
+# Tests for knowledge_base_metadata_lock_check
+# ============================================================================
+
+
+class TestMetadataServiceLockCheck:
+ """
+ Comprehensive unit tests for MetadataService.knowledge_base_metadata_lock_check method.
+
+ This test class covers the lock management functionality for preventing
+ concurrent metadata operations.
+
+ The knowledge_base_metadata_lock_check method:
+ 1. Checks if a lock exists for the dataset or document
+ 2. Raises ValueError if lock exists (operation in progress)
+ 3. Sets a lock with expiration time (3600 seconds)
+ 4. Supports both dataset-level and document-level locks
+
+ Test scenarios include:
+ - Successful lock acquisition
+ - Lock conflict detection
+ - Dataset-level locks
+ - Document-level locks
+ """
+
+ @pytest.fixture
+ def mock_redis_client(self):
+ """Mock Redis client for lock management."""
+ with patch("services.metadata_service.redis_client") as mock_redis:
+ yield mock_redis
+
+ def test_lock_check_dataset_success(self, mock_redis_client):
+ """
+ Test successful lock acquisition for dataset operations.
+
+ Verifies that when no lock exists, a new lock is acquired
+ for the dataset.
+
+ This test ensures:
+ - Lock check passes when no lock exists
+ - Lock is set with correct key and expiration
+ - No error is raised
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ mock_redis_client.get.return_value = None # No existing lock
+
+ # Act (should not raise)
+ MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
+
+ # Assert
+ mock_redis_client.get.assert_called_once_with(f"dataset_metadata_lock_{dataset_id}")
+ mock_redis_client.set.assert_called_once_with(f"dataset_metadata_lock_{dataset_id}", 1, ex=3600)
+
+ def test_lock_check_dataset_conflict_error(self, mock_redis_client):
+ """
+ Test error handling when dataset lock already exists.
+
+ Verifies that when a lock exists for the dataset, a ValueError
+ is raised with an appropriate message.
+
+ This test ensures:
+ - Lock conflict is detected
+ - Error message is clear
+ - No new lock is set
+ """
+ # Arrange
+ dataset_id = "dataset-123"
+ mock_redis_client.get.return_value = "1" # Lock exists
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Another knowledge base metadata operation is running"):
+ MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
+
+ # Verify lock was checked but not set
+ mock_redis_client.get.assert_called_once()
+ mock_redis_client.set.assert_not_called()
+
+ def test_lock_check_document_success(self, mock_redis_client):
+ """
+ Test successful lock acquisition for document operations.
+
+ Verifies that when no lock exists, a new lock is acquired
+ for the document.
+
+ This test ensures:
+ - Lock check passes when no lock exists
+ - Lock is set with correct key and expiration
+ - No error is raised
+ """
+ # Arrange
+ document_id = "document-123"
+ mock_redis_client.get.return_value = None # No existing lock
+
+ # Act (should not raise)
+ MetadataService.knowledge_base_metadata_lock_check(None, document_id)
+
+ # Assert
+ mock_redis_client.get.assert_called_once_with(f"document_metadata_lock_{document_id}")
+ mock_redis_client.set.assert_called_once_with(f"document_metadata_lock_{document_id}", 1, ex=3600)
+
+
+# ============================================================================
+# Tests for get_dataset_metadatas
+# ============================================================================
+
+
+class TestMetadataServiceGetDatasetMetadatas:
+ """
+ Comprehensive unit tests for MetadataService.get_dataset_metadatas method.
+
+ This test class covers the metadata retrieval functionality for datasets.
+
+ The get_dataset_metadatas method:
+ 1. Retrieves all metadata fields for a dataset
+ 2. Excludes built-in fields from the list
+ 3. Includes usage count for each metadata field
+ 4. Returns built-in field enabled status
+
+ Test scenarios include:
+ - Successful retrieval with metadata fields
+ - Empty metadata list
+ - Built-in field filtering
+ - Usage count calculation
+ """
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """Mock database session for testing."""
+ with patch("services.metadata_service.db.session") as mock_db:
+ yield mock_db
+
+ def test_get_dataset_metadatas_success(self, mock_db_session):
+ """
+ Test successful retrieval of dataset metadata fields.
+
+ Verifies that all metadata fields are returned with correct
+ structure and usage counts.
+
+ This test ensures:
+ - All metadata fields are included
+ - Built-in fields are excluded
+ - Usage counts are calculated correctly
+ - Built-in field status is included
+ """
+ # Arrange
+ dataset = MetadataTestDataFactory.create_dataset_mock(
+ dataset_id="dataset-123",
+ built_in_field_enabled=True,
+ doc_metadata=[
+ {"id": "metadata-1", "name": "category", "type": "string"},
+ {"id": "metadata-2", "name": "priority", "type": "number"},
+ {"id": "built-in", "name": "document_name", "type": "string"},
+ ],
+ )
+
+ # Mock usage count queries
+ mock_query = Mock()
+ mock_query.filter_by.return_value = mock_query
+ mock_query.count.return_value = 5 # 5 documents use this metadata
+ mock_db_session.query.return_value = mock_query
+
+ # Act
+ result = MetadataService.get_dataset_metadatas(dataset)
+
+ # Assert
+ assert "doc_metadata" in result
+ assert "built_in_field_enabled" in result
+ assert result["built_in_field_enabled"] is True
+
+ # Verify built-in fields are excluded
+ metadata_ids = [meta["id"] for meta in result["doc_metadata"]]
+ assert "built-in" not in metadata_ids
+
+ # Verify all custom metadata fields are included
+ assert len(result["doc_metadata"]) == 2
+
+ # Verify usage counts are included
+ for meta in result["doc_metadata"]:
+ assert "count" in meta
+ assert meta["count"] == 5
+
+
+# ============================================================================
+# Additional Documentation and Notes
+# ============================================================================
+#
+# This test suite covers the core metadata CRUD operations and basic
+# filtering functionality. Additional test scenarios that could be added:
+#
+# 1. enable_built_in_field / disable_built_in_field:
+# - Testing built-in field enablement
+# - Testing built-in field disablement
+# - Testing document metadata updates when enabling/disabling
+#
+# 2. update_documents_metadata:
+# - Testing partial updates
+# - Testing full updates
+# - Testing metadata binding creation
+# - Testing built-in field updates
+#
+# 3. Metadata Filtering and Querying:
+# - Testing metadata-based document filtering
+# - Testing complex metadata queries
+# - Testing metadata value retrieval
+#
+# These scenarios are not currently implemented but could be added if needed
+# based on real-world usage patterns or discovered edge cases.
+#
+# ============================================================================
From 8d8800e632a417d21ebaa06e784e66022596a4fc Mon Sep 17 00:00:00 2001
From: majinghe <42570491+majinghe@users.noreply.github.com>
Date: Thu, 27 Nov 2025 11:01:14 +0800
Subject: [PATCH 025/431] upgrade docker compose milvus version to 2.6.0 to fix
installation error (#26618)
Co-authored-by: crazywoola <427733928@qq.com>
---
docker/docker-compose-template.yaml | 2 +-
docker/docker-compose.yaml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index 975c92693a..703a60ef67 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -676,7 +676,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
- image: milvusdb/milvus:v2.5.15
+ image: milvusdb/milvus:v2.6.3
profiles:
- milvus
command: ["milvus", "run", "standalone"]
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 17f33bbf72..de2e3943fe 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -1311,7 +1311,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
- image: milvusdb/milvus:v2.5.15
+ image: milvusdb/milvus:v2.6.3
profiles:
- milvus
command: ["milvus", "run", "standalone"]
From f9b4c3134441f4c2547ad4613d2fb1800e7e1ab8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Thu, 27 Nov 2025 11:22:49 +0800
Subject: [PATCH 026/431] fix: MCP tool time configuration not work (#28740)
---
web/app/components/tools/mcp/modal.tsx | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx
index 68f97703bf..836fc5e0aa 100644
--- a/web/app/components/tools/mcp/modal.tsx
+++ b/web/app/components/tools/mcp/modal.tsx
@@ -99,8 +99,8 @@ const MCPModal = ({
const [appIcon, setAppIcon] = useState(() => getIcon(data))
const [showAppIconPicker, setShowAppIconPicker] = useState(false)
const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '')
- const [timeout, setMcpTimeout] = React.useState(data?.timeout || 30)
- const [sseReadTimeout, setSseReadTimeout] = React.useState(data?.sse_read_timeout || 300)
+ const [timeout, setMcpTimeout] = React.useState(data?.configuration?.timeout || 30)
+ const [sseReadTimeout, setSseReadTimeout] = React.useState(data?.configuration?.sse_read_timeout || 300)
const [headers, setHeaders] = React.useState(
Object.entries(data?.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })),
)
@@ -118,8 +118,8 @@ const MCPModal = ({
setUrl(data.server_url || '')
setName(data.name || '')
setServerIdentifier(data.server_identifier || '')
- setMcpTimeout(data.timeout || 30)
- setSseReadTimeout(data.sse_read_timeout || 300)
+ setMcpTimeout(data.configuration?.timeout || 30)
+ setSseReadTimeout(data.configuration?.sse_read_timeout || 300)
setHeaders(Object.entries(data.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })))
setAppIcon(getIcon(data))
setIsDynamicRegistration(data.is_dynamic_registration)
From 6deabfdad38f4f7ed4ff9d2f945e2a8385316ea6 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 27 Nov 2025 11:23:20 +0800
Subject: [PATCH 027/431] Use naive_utc_now in graph engine tests (#28735)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../event_management/test_event_handlers.py | 5 ++---
.../graph_engine/orchestration/test_dispatcher.py | 10 +++++-----
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
index 2b8f04979d..5d17b7a243 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
@@ -2,8 +2,6 @@
from __future__ import annotations
-from datetime import datetime
-
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
@@ -16,6 +14,7 @@ from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
from core.workflow.runtime import GraphRuntimeState, VariablePool
+from libs.datetime_utils import naive_utc_now
class _StubEdgeProcessor:
@@ -75,7 +74,7 @@ def test_retry_does_not_emit_additional_start_event() -> None:
execution_id = "exec-1"
node_type = NodeType.CODE
- start_time = datetime.utcnow()
+ start_time = naive_utc_now()
start_event = NodeRunStartedEvent(
id=execution_id,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
index e6d4508fdf..c1fc4acd73 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
@@ -3,7 +3,6 @@
from __future__ import annotations
import queue
-from datetime import datetime
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@@ -18,6 +17,7 @@ from core.workflow.graph_events import (
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
+from libs.datetime_utils import naive_utc_now
def test_dispatcher_should_consume_remains_events_after_pause():
@@ -109,7 +109,7 @@ def _make_started_event() -> NodeRunStartedEvent:
node_id="node-1",
node_type=NodeType.CODE,
node_title="Test Node",
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
)
@@ -119,7 +119,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:
node_id="node-1",
node_type=NodeType.CODE,
node_title="Test Node",
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
)
@@ -153,7 +153,7 @@ def test_dispatcher_drain_event_queue():
node_id="node-1",
node_type=NodeType.CODE,
node_title="Code",
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
),
NodeRunPauseRequestedEvent(
id="pause-event",
@@ -165,7 +165,7 @@ def test_dispatcher_drain_event_queue():
id="success-event",
node_id="node-1",
node_type=NodeType.CODE,
- start_at=datetime.utcnow(),
+ start_at=naive_utc_now(),
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
),
]
From 0309545ff15d2a79087a5875d99c036c301ccc74 Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Wed, 26 Nov 2025 22:23:55 -0500
Subject: [PATCH 028/431] Feat/test script of workflow service (#28726)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../services/test_workflow_service.py | 1114 +++++++++++++++++
1 file changed, 1114 insertions(+)
create mode 100644 api/tests/unit_tests/services/test_workflow_service.py
diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py
new file mode 100644
index 0000000000..ae5b194afb
--- /dev/null
+++ b/api/tests/unit_tests/services/test_workflow_service.py
@@ -0,0 +1,1114 @@
+"""
+Unit tests for WorkflowService.
+
+This test suite covers:
+- Workflow creation from template
+- Workflow validation (graph and features structure)
+- Draft/publish transitions
+- Version management
+- Execution triggering
+"""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.workflow.enums import NodeType
+from libs.datetime_utils import naive_utc_now
+from models.model import App, AppMode
+from models.workflow import Workflow, WorkflowType
+from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError
+from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
+from services.workflow_service import WorkflowService
+
+
+class TestWorkflowAssociatedDataFactory:
+ """
+ Factory class for creating test data and mock objects for workflow service tests.
+
+ This factory provides reusable methods to create mock objects for:
+ - App models with configurable attributes
+ - Workflow models with graph and feature configurations
+ - Account models for user authentication
+ - Valid workflow graph structures for testing
+
+ All factory methods return MagicMock objects that simulate database models
+ without requiring actual database connections.
+ """
+
+ @staticmethod
+ def create_app_mock(
+ app_id: str = "app-123",
+ tenant_id: str = "tenant-456",
+ mode: str = AppMode.WORKFLOW.value,
+ workflow_id: str | None = None,
+ **kwargs,
+ ) -> MagicMock:
+ """
+ Create a mock App with specified attributes.
+
+ Args:
+ app_id: Unique identifier for the app
+ tenant_id: Workspace/tenant identifier
+ mode: App mode (workflow, chat, completion, etc.)
+ workflow_id: Optional ID of the published workflow
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ MagicMock object configured as an App model
+ """
+ app = MagicMock(spec=App)
+ app.id = app_id
+ app.tenant_id = tenant_id
+ app.mode = mode
+ app.workflow_id = workflow_id
+ for key, value in kwargs.items():
+ setattr(app, key, value)
+ return app
+
+ @staticmethod
+ def create_workflow_mock(
+ workflow_id: str = "workflow-789",
+ tenant_id: str = "tenant-456",
+ app_id: str = "app-123",
+ version: str = Workflow.VERSION_DRAFT,
+ workflow_type: str = WorkflowType.WORKFLOW.value,
+ graph: dict | None = None,
+ features: dict | None = None,
+ unique_hash: str | None = None,
+ **kwargs,
+ ) -> MagicMock:
+ """
+ Create a mock Workflow with specified attributes.
+
+ Args:
+ workflow_id: Unique identifier for the workflow
+ tenant_id: Workspace/tenant identifier
+ app_id: Associated app identifier
+ version: Workflow version ("draft" or timestamp-based version)
+ workflow_type: Type of workflow (workflow, chat, rag-pipeline)
+ graph: Workflow graph structure containing nodes and edges
+ features: Feature configuration (file upload, text-to-speech, etc.)
+ unique_hash: Hash for optimistic locking during updates
+ **kwargs: Additional attributes to set on the mock
+
+ Returns:
+ MagicMock object configured as a Workflow model with graph/features
+ """
+ workflow = MagicMock(spec=Workflow)
+ workflow.id = workflow_id
+ workflow.tenant_id = tenant_id
+ workflow.app_id = app_id
+ workflow.version = version
+ workflow.type = workflow_type
+
+ # Set up graph and features with defaults if not provided
+ # Graph contains the workflow structure (nodes and their connections)
+ if graph is None:
+ graph = {"nodes": [], "edges": []}
+ # Features contain app-level configurations like file upload settings
+ if features is None:
+ features = {}
+
+ workflow.graph = json.dumps(graph)
+ workflow.features = json.dumps(features)
+ workflow.graph_dict = graph
+ workflow.features_dict = features
+ workflow.unique_hash = unique_hash or "test-hash-123"
+ workflow.environment_variables = []
+ workflow.conversation_variables = []
+ workflow.rag_pipeline_variables = []
+ workflow.created_by = "user-123"
+ workflow.updated_by = None
+ workflow.created_at = naive_utc_now()
+ workflow.updated_at = naive_utc_now()
+
+ # Mock walk_nodes method to iterate through workflow nodes
+ # This is used by the service to traverse and validate workflow structure
+ def walk_nodes_side_effect(specific_node_type=None):
+ nodes = graph.get("nodes", [])
+ # Filter by node type if specified (e.g., only LLM nodes)
+ if specific_node_type:
+ return (
+ (node["id"], node["data"])
+ for node in nodes
+ if node.get("data", {}).get("type") == specific_node_type.value
+ )
+ # Return all nodes if no filter specified
+ return ((node["id"], node["data"]) for node in nodes)
+
+ workflow.walk_nodes = walk_nodes_side_effect
+
+ for key, value in kwargs.items():
+ setattr(workflow, key, value)
+ return workflow
+
+ @staticmethod
+ def create_account_mock(account_id: str = "user-123", **kwargs) -> MagicMock:
+ """Create a mock Account with specified attributes."""
+ account = MagicMock()
+ account.id = account_id
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_valid_workflow_graph(include_start: bool = True, include_trigger: bool = False) -> dict:
+ """
+ Create a valid workflow graph structure for testing.
+
+ Args:
+ include_start: Whether to include a START node (for regular workflows)
+ include_trigger: Whether to include trigger nodes (webhook, schedule, etc.)
+
+ Returns:
+ Dictionary containing nodes and edges arrays representing workflow graph
+
+ Note:
+ Start nodes and trigger nodes cannot coexist in the same workflow.
+ This is validated by the workflow service.
+ """
+ nodes = []
+ edges = []
+
+ # Add START node for regular workflows (user-initiated)
+ if include_start:
+ nodes.append(
+ {
+ "id": "start",
+ "data": {
+ "type": NodeType.START.value,
+ "title": "START",
+ "variables": [],
+ },
+ }
+ )
+
+ # Add trigger node for event-driven workflows (webhook, schedule, etc.)
+ if include_trigger:
+ nodes.append(
+ {
+ "id": "trigger-1",
+ "data": {
+ "type": "http-request",
+ "title": "HTTP Request Trigger",
+ },
+ }
+ )
+
+ # Add an LLM node as a sample processing node
+ # This represents an AI model interaction in the workflow
+ nodes.append(
+ {
+ "id": "llm-1",
+ "data": {
+ "type": NodeType.LLM.value,
+ "title": "LLM",
+ "model": {
+ "provider": "openai",
+ "name": "gpt-4",
+ },
+ },
+ }
+ )
+
+ return {"nodes": nodes, "edges": edges}
+
+
+class TestWorkflowService:
+ """
+ Comprehensive unit tests for WorkflowService methods.
+
+ This test suite covers:
+ - Workflow creation from template
+ - Workflow validation (graph and features)
+ - Draft/publish transitions
+ - Version management
+ - Workflow deletion and error handling
+ """
+
+ @pytest.fixture
+ def workflow_service(self):
+ """
+ Create a WorkflowService instance with mocked dependencies.
+
+ This fixture patches the database to avoid real database connections
+ during testing. Each test gets a fresh service instance.
+ """
+ with patch("services.workflow_service.db"):
+ service = WorkflowService()
+ return service
+
+ @pytest.fixture
+ def mock_db_session(self):
+ """
+ Mock database session for testing database operations.
+
+ Provides mock implementations of:
+ - session.add(): Adding new records
+ - session.commit(): Committing transactions
+ - session.query(): Querying database
+ - session.execute(): Executing SQL statements
+ """
+ with patch("services.workflow_service.db") as mock_db:
+ mock_session = MagicMock()
+ mock_db.session = mock_session
+ mock_session.add = MagicMock()
+ mock_session.commit = MagicMock()
+ mock_session.query = MagicMock()
+ mock_session.execute = MagicMock()
+ yield mock_db
+
+ @pytest.fixture
+ def mock_sqlalchemy_session(self):
+ """
+ Mock SQLAlchemy Session for publish_workflow tests.
+
+ This is a separate fixture because publish_workflow uses
+ SQLAlchemy's Session class directly rather than the Flask-SQLAlchemy
+ db.session object.
+ """
+ mock_session = MagicMock()
+ mock_session.add = MagicMock()
+ mock_session.commit = MagicMock()
+ mock_session.scalar = MagicMock()
+ return mock_session
+
+ # ==================== Workflow Existence Tests ====================
+ # These tests verify the service can check if a draft workflow exists
+
+ def test_is_workflow_exist_returns_true(self, workflow_service, mock_db_session):
+ """
+ Test is_workflow_exist returns True when draft workflow exists.
+
+ Verifies that the service correctly identifies when an app has a draft workflow.
+ This is used to determine whether to create or update a workflow.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+
+ # Mock the database query to return True
+ mock_db_session.session.execute.return_value.scalar_one.return_value = True
+
+ result = workflow_service.is_workflow_exist(app)
+
+ assert result is True
+
+ def test_is_workflow_exist_returns_false(self, workflow_service, mock_db_session):
+ """Test is_workflow_exist returns False when no draft workflow exists."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+
+ # Mock the database query to return False
+ mock_db_session.session.execute.return_value.scalar_one.return_value = False
+
+ result = workflow_service.is_workflow_exist(app)
+
+ assert result is False
+
+ # ==================== Get Draft Workflow Tests ====================
+ # These tests verify retrieval of draft workflows (version="draft")
+
+ def test_get_draft_workflow_success(self, workflow_service, mock_db_session):
+ """
+ Test get_draft_workflow returns draft workflow successfully.
+
+ Draft workflows are the working copy that users edit before publishing.
+ Each app can have only one draft workflow at a time.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_draft_workflow(app)
+
+ assert result == mock_workflow
+
+ def test_get_draft_workflow_returns_none(self, workflow_service, mock_db_session):
+ """Test get_draft_workflow returns None when no draft exists."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+
+ # Mock database query to return None
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = None
+
+ result = workflow_service.get_draft_workflow(app)
+
+ assert result is None
+
+ def test_get_draft_workflow_with_workflow_id(self, workflow_service, mock_db_session):
+ """Test get_draft_workflow with workflow_id calls get_published_workflow_by_id."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "workflow-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
+
+ assert result == mock_workflow
+
+ # ==================== Get Published Workflow Tests ====================
+ # These tests verify retrieval of published workflows (versioned snapshots)
+
+ def test_get_published_workflow_by_id_success(self, workflow_service, mock_db_session):
+ """Test get_published_workflow_by_id returns published workflow."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "workflow-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_published_workflow_by_id(app, workflow_id)
+
+ assert result == mock_workflow
+
+ def test_get_published_workflow_by_id_raises_error_for_draft(self, workflow_service, mock_db_session):
+ """
+ Test get_published_workflow_by_id raises error when workflow is draft.
+
+ This prevents using draft workflows in production contexts where only
+ published, stable versions should be used (e.g., API execution).
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "workflow-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(
+ workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
+ )
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ with pytest.raises(IsDraftWorkflowError):
+ workflow_service.get_published_workflow_by_id(app, workflow_id)
+
+ def test_get_published_workflow_by_id_returns_none(self, workflow_service, mock_db_session):
+ """Test get_published_workflow_by_id returns None when workflow not found."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ workflow_id = "nonexistent-workflow"
+
+ # Mock database query to return None
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = None
+
+ result = workflow_service.get_published_workflow_by_id(app, workflow_id)
+
+ assert result is None
+
+ def test_get_published_workflow_success(self, workflow_service, mock_db_session):
+ """Test get_published_workflow returns published workflow."""
+ workflow_id = "workflow-123"
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+
+ # Mock database query
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ result = workflow_service.get_published_workflow(app)
+
+ assert result == mock_workflow
+
+ def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service):
+ """Test get_published_workflow returns None when app has no workflow_id."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
+
+ result = workflow_service.get_published_workflow(app)
+
+ assert result is None
+
+ # ==================== Sync Draft Workflow Tests ====================
+ # These tests verify creating and updating draft workflows with validation
+
+ def test_sync_draft_workflow_creates_new_draft(self, workflow_service, mock_db_session):
+ """
+ Test sync_draft_workflow creates new draft workflow when none exists.
+
+ When a user first creates a workflow app, this creates the initial draft.
+ The draft is validated before creation to ensure graph and features are valid.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+ features = {"file_upload": {"enabled": False}}
+
+ # Mock get_draft_workflow to return None (no existing draft)
+ # This simulates the first time a workflow is created for an app
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = None
+
+ with (
+ patch.object(workflow_service, "validate_features_structure"),
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.app_draft_workflow_was_synced"),
+ ):
+ result = workflow_service.sync_draft_workflow(
+ app_model=app,
+ graph=graph,
+ features=features,
+ unique_hash=None,
+ account=account,
+ environment_variables=[],
+ conversation_variables=[],
+ )
+
+ # Verify workflow was added to session
+ mock_db_session.session.add.assert_called_once()
+ mock_db_session.session.commit.assert_called_once()
+
+ def test_sync_draft_workflow_updates_existing_draft(self, workflow_service, mock_db_session):
+ """
+ Test sync_draft_workflow updates existing draft workflow.
+
+ When users edit their workflow, this updates the existing draft.
+ The unique_hash is used for optimistic locking to prevent conflicts.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+ features = {"file_upload": {"enabled": False}}
+ unique_hash = "test-hash-123"
+
+ # Mock existing draft workflow
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
+
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ with (
+ patch.object(workflow_service, "validate_features_structure"),
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.app_draft_workflow_was_synced"),
+ ):
+ result = workflow_service.sync_draft_workflow(
+ app_model=app,
+ graph=graph,
+ features=features,
+ unique_hash=unique_hash,
+ account=account,
+ environment_variables=[],
+ conversation_variables=[],
+ )
+
+ # Verify workflow was updated
+ assert mock_workflow.graph == json.dumps(graph)
+ assert mock_workflow.features == json.dumps(features)
+ assert mock_workflow.updated_by == account.id
+ mock_db_session.session.commit.assert_called_once()
+
+ def test_sync_draft_workflow_raises_hash_not_equal_error(self, workflow_service, mock_db_session):
+ """
+ Test sync_draft_workflow raises error when hash doesn't match.
+
+ This implements optimistic locking: if the workflow was modified by another
+ user/session since it was loaded, the hash won't match and the update fails.
+ This prevents overwriting concurrent changes.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+ features = {}
+
+ # Mock existing draft workflow with different hash
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
+
+ mock_query = MagicMock()
+ mock_db_session.session.query.return_value = mock_query
+ mock_query.where.return_value.first.return_value = mock_workflow
+
+ with pytest.raises(WorkflowHashNotEqualError):
+ workflow_service.sync_draft_workflow(
+ app_model=app,
+ graph=graph,
+ features=features,
+ unique_hash="new-hash",
+ account=account,
+ environment_variables=[],
+ conversation_variables=[],
+ )
+
+ # ==================== Workflow Validation Tests ====================
+ # These tests verify graph structure and feature configuration validation
+
+ def test_validate_graph_structure_empty_graph(self, workflow_service):
+ """Test validate_graph_structure accepts empty graph."""
+ graph = {"nodes": []}
+
+ # Should not raise any exception
+ workflow_service.validate_graph_structure(graph)
+
+ def test_validate_graph_structure_valid_graph(self, workflow_service):
+ """Test validate_graph_structure accepts valid graph."""
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+
+ # Should not raise any exception
+ workflow_service.validate_graph_structure(graph)
+
+ def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service):
+ """
+ Test validate_graph_structure raises error when start and trigger nodes coexist.
+
+ Workflows can be either:
+ - User-initiated (with START node): User provides input to start execution
+ - Event-driven (with trigger nodes): External events trigger execution
+
+ These two patterns cannot be mixed in a single workflow.
+ """
+ # Create a graph with both start and trigger nodes
+ # Use actual trigger node types: trigger-webhook, trigger-schedule, trigger-plugin
+ graph = {
+ "nodes": [
+ {
+ "id": "start",
+ "data": {
+ "type": "start",
+ "title": "START",
+ },
+ },
+ {
+ "id": "trigger-1",
+ "data": {
+ "type": "trigger-webhook",
+ "title": "Webhook Trigger",
+ },
+ },
+ ],
+ "edges": [],
+ }
+
+ with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
+ workflow_service.validate_graph_structure(graph)
+
+ def test_validate_features_structure_workflow_mode(self, workflow_service):
+ """
+ Test validate_features_structure for workflow mode.
+
+ Different app modes have different feature configurations.
+ This ensures the features match the expected schema for workflow apps.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ features = {"file_upload": {"enabled": False}}
+
+ with patch("services.workflow_service.WorkflowAppConfigManager.config_validate") as mock_validate:
+ workflow_service.validate_features_structure(app, features)
+ mock_validate.assert_called_once_with(
+ tenant_id=app.tenant_id, config=features, only_structure_validate=True
+ )
+
+ def test_validate_features_structure_advanced_chat_mode(self, workflow_service):
+ """Test validate_features_structure for advanced chat mode."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value)
+ features = {"opening_statement": "Hello"}
+
+ with patch("services.workflow_service.AdvancedChatAppConfigManager.config_validate") as mock_validate:
+ workflow_service.validate_features_structure(app, features)
+ mock_validate.assert_called_once_with(
+ tenant_id=app.tenant_id, config=features, only_structure_validate=True
+ )
+
+ def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service):
+ """Test validate_features_structure raises error for invalid mode."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION.value)
+ features = {}
+
+ with pytest.raises(ValueError, match="Invalid app mode"):
+ workflow_service.validate_features_structure(app, features)
+
+ # ==================== Publish Workflow Tests ====================
+ # These tests verify creating published versions from draft workflows
+
+ def test_publish_workflow_success(self, workflow_service, mock_sqlalchemy_session):
+ """
+ Test publish_workflow creates new published version.
+
+ Publishing creates a timestamped snapshot of the draft workflow.
+ This allows users to:
+ - Roll back to previous versions
+ - Use stable versions in production
+ - Continue editing draft without affecting published version
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
+
+ # Mock draft workflow
+ mock_draft = TestWorkflowAssociatedDataFactory.create_workflow_mock(version=Workflow.VERSION_DRAFT, graph=graph)
+ mock_sqlalchemy_session.scalar.return_value = mock_draft
+
+ with (
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.app_published_workflow_was_updated"),
+ patch("services.workflow_service.dify_config") as mock_config,
+ patch("services.workflow_service.Workflow.new") as mock_workflow_new,
+ ):
+ # Disable billing
+ mock_config.BILLING_ENABLED = False
+
+ # Mock Workflow.new to return a new workflow
+ mock_new_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
+ mock_workflow_new.return_value = mock_new_workflow
+
+ result = workflow_service.publish_workflow(
+ session=mock_sqlalchemy_session,
+ app_model=app,
+ account=account,
+ marked_name="Version 1",
+ marked_comment="Initial release",
+ )
+
+ # Verify workflow was added to session
+ mock_sqlalchemy_session.add.assert_called_once_with(mock_new_workflow)
+ assert result == mock_new_workflow
+
+ def test_publish_workflow_no_draft_raises_error(self, workflow_service, mock_sqlalchemy_session):
+ """
+ Test publish_workflow raises error when no draft exists.
+
+ Cannot publish if there's no draft to publish from.
+ Users must create and save a draft before publishing.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+
+ # Mock no draft workflow
+ mock_sqlalchemy_session.scalar.return_value = None
+
+ with pytest.raises(ValueError, match="No valid workflow found"):
+ workflow_service.publish_workflow(session=mock_sqlalchemy_session, app_model=app, account=account)
+
+ def test_publish_workflow_trigger_limit_exceeded(self, workflow_service, mock_sqlalchemy_session):
+ """
+ Test publish_workflow raises error when trigger node limit exceeded in SANDBOX plan.
+
+ Free/sandbox tier users have limits on the number of trigger nodes.
+ This prevents resource abuse while allowing users to test the feature.
+ The limit is enforced at publish time, not during draft editing.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock()
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+
+ # Create graph with 3 trigger nodes (exceeds SANDBOX limit of 2)
+ # Trigger nodes enable event-driven automation which consumes resources
+ graph = {
+ "nodes": [
+ {"id": "trigger-1", "data": {"type": "trigger-webhook"}},
+ {"id": "trigger-2", "data": {"type": "trigger-schedule"}},
+ {"id": "trigger-3", "data": {"type": "trigger-plugin"}},
+ ],
+ "edges": [],
+ }
+ mock_draft = TestWorkflowAssociatedDataFactory.create_workflow_mock(version=Workflow.VERSION_DRAFT, graph=graph)
+ mock_sqlalchemy_session.scalar.return_value = mock_draft
+
+ with (
+ patch.object(workflow_service, "validate_graph_structure"),
+ patch("services.workflow_service.dify_config") as mock_config,
+ patch("services.workflow_service.BillingService") as MockBillingService,
+ patch("services.workflow_service.app_published_workflow_was_updated"),
+ ):
+ # Enable billing and set SANDBOX plan
+ mock_config.BILLING_ENABLED = True
+ MockBillingService.get_info.return_value = {"subscription": {"plan": "sandbox"}}
+
+ with pytest.raises(TriggerNodeLimitExceededError):
+ workflow_service.publish_workflow(session=mock_sqlalchemy_session, app_model=app, account=account)
+
+ # ==================== Version Management Tests ====================
+ # These tests verify listing and managing published workflow versions
+
+ def test_get_all_published_workflow_with_pagination(self, workflow_service):
+ """
+ Test get_all_published_workflow returns paginated results.
+
+ Apps can have many published versions over time.
+ Pagination prevents loading all versions at once, improving performance.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id="workflow-123")
+
+ # Mock workflows
+ mock_workflows = [
+ TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=f"workflow-{i}", version=f"v{i}")
+ for i in range(5)
+ ]
+
+ mock_session = MagicMock()
+ mock_session.scalars.return_value.all.return_value = mock_workflows
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.offset.return_value = mock_stmt
+
+ workflows, has_more = workflow_service.get_all_published_workflow(
+ session=mock_session, app_model=app, page=1, limit=10, user_id=None
+ )
+
+ assert len(workflows) == 5
+ assert has_more is False
+
+ def test_get_all_published_workflow_has_more(self, workflow_service):
+ """
+ Test get_all_published_workflow indicates has_more when results exceed limit.
+
+ The has_more flag tells the UI whether to show a "Load More" button.
+ This is determined by fetching limit+1 records and checking if we got that many.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id="workflow-123")
+
+ # Mock 11 workflows (limit is 10, so has_more should be True)
+ mock_workflows = [
+ TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=f"workflow-{i}", version=f"v{i}")
+ for i in range(11)
+ ]
+
+ mock_session = MagicMock()
+ mock_session.scalars.return_value.all.return_value = mock_workflows
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+ mock_stmt.order_by.return_value = mock_stmt
+ mock_stmt.limit.return_value = mock_stmt
+ mock_stmt.offset.return_value = mock_stmt
+
+ workflows, has_more = workflow_service.get_all_published_workflow(
+ session=mock_session, app_model=app, page=1, limit=10, user_id=None
+ )
+
+ assert len(workflows) == 10
+ assert has_more is True
+
+ def test_get_all_published_workflow_no_workflow_id(self, workflow_service):
+ """Test get_all_published_workflow returns empty when app has no workflow_id."""
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
+ mock_session = MagicMock()
+
+ workflows, has_more = workflow_service.get_all_published_workflow(
+ session=mock_session, app_model=app, page=1, limit=10, user_id=None
+ )
+
+ assert workflows == []
+ assert has_more is False
+
+ # ==================== Update Workflow Tests ====================
+ # These tests verify updating workflow metadata (name, comments, etc.)
+
+ def test_update_workflow_success(self, workflow_service):
+ """
+ Test update_workflow updates workflow attributes.
+
+ Allows updating metadata like marked_name and marked_comment
+ without creating a new version. Only specific fields are allowed
+ to prevent accidental modification of workflow logic.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ account_id = "user-123"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id)
+
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = mock_workflow
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ result = workflow_service.update_workflow(
+ session=mock_session,
+ workflow_id=workflow_id,
+ tenant_id=tenant_id,
+ account_id=account_id,
+ data={"marked_name": "Updated Name", "marked_comment": "Updated Comment"},
+ )
+
+ assert result == mock_workflow
+ assert mock_workflow.marked_name == "Updated Name"
+ assert mock_workflow.marked_comment == "Updated Comment"
+ assert mock_workflow.updated_by == account_id
+
+ def test_update_workflow_not_found(self, workflow_service):
+ """Test update_workflow returns None when workflow not found."""
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = None
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ result = workflow_service.update_workflow(
+ session=mock_session,
+ workflow_id="nonexistent",
+ tenant_id="tenant-456",
+ account_id="user-123",
+ data={"marked_name": "Test"},
+ )
+
+ assert result is None
+
+ # ==================== Delete Workflow Tests ====================
+ # These tests verify workflow deletion with safety checks
+
+ def test_delete_workflow_success(self, workflow_service):
+ """
+ Test delete_workflow successfully deletes a published workflow.
+
+ Users can delete old published versions they no longer need.
+ This helps manage storage and keeps the version list clean.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+
+ mock_session = MagicMock()
+ # Mock successful deletion scenario:
+ # 1. Workflow exists
+ # 2. No app is currently using it
+ # 3. Not published as a tool
+ mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
+ mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ result = workflow_service.delete_workflow(
+ session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id
+ )
+
+ assert result is True
+ mock_session.delete.assert_called_once_with(mock_workflow)
+
+ def test_delete_workflow_draft_raises_error(self, workflow_service):
+ """
+ Test delete_workflow raises error when trying to delete draft.
+
+ Draft workflows cannot be deleted - they're the working copy.
+ Users can only delete published versions to clean up old snapshots.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(
+ workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
+ )
+
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = mock_workflow
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service):
+ """
+ Test delete_workflow raises error when workflow is in use by app.
+
+ Cannot delete a workflow version that's currently published/active.
+ This would break the app for users. Must publish a different version first.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+ mock_app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
+
+ mock_session = MagicMock()
+ mock_session.scalar.side_effect = [mock_workflow, mock_app]
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(WorkflowInUseError, match="currently in use by app"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ def test_delete_workflow_published_as_tool_raises_error(self, workflow_service):
+ """
+ Test delete_workflow raises error when workflow is published as tool.
+
+ Workflows can be published as reusable tools for other workflows.
+ Cannot delete a version that's being used as a tool, as this would
+ break other workflows that depend on it.
+ """
+ workflow_id = "workflow-123"
+ tenant_id = "tenant-456"
+ mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
+ mock_tool_provider = MagicMock()
+
+ mock_session = MagicMock()
+ mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
+ mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(WorkflowInUseError, match="published as a tool"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ def test_delete_workflow_not_found_raises_error(self, workflow_service):
+ """Test delete_workflow raises error when workflow not found."""
+ workflow_id = "nonexistent"
+ tenant_id = "tenant-456"
+
+ mock_session = MagicMock()
+ mock_session.scalar.return_value = None
+
+ with patch("services.workflow_service.select") as mock_select:
+ mock_stmt = MagicMock()
+ mock_select.return_value = mock_stmt
+ mock_stmt.where.return_value = mock_stmt
+
+ with pytest.raises(ValueError, match="not found"):
+ workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
+
+ # ==================== Get Default Block Config Tests ====================
+ # These tests verify retrieval of default node configurations
+
+ def test_get_default_block_configs(self, workflow_service):
+ """
+ Test get_default_block_configs returns list of default configs.
+
+ Returns default configurations for all available node types.
+ Used by the UI to populate the node palette and provide sensible defaults
+ when users add new nodes to their workflow.
+ """
+ with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping:
+ # Mock node class with default config
+ mock_node_class = MagicMock()
+ mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}}
+
+ mock_mapping.values.return_value = [{"latest": mock_node_class}]
+
+ with patch("services.workflow_service.LATEST_VERSION", "latest"):
+ result = workflow_service.get_default_block_configs()
+
+ assert len(result) > 0
+
+ def test_get_default_block_config_for_node_type(self, workflow_service):
+ """
+ Test get_default_block_config returns config for specific node type.
+
+ Returns the default configuration for a specific node type (e.g., LLM, HTTP).
+ This includes default values for all required and optional parameters.
+ """
+ with (
+ patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping,
+ patch("services.workflow_service.LATEST_VERSION", "latest"),
+ ):
+ # Mock node class with default config
+ mock_node_class = MagicMock()
+ mock_config = {"type": "llm", "config": {"provider": "openai"}}
+ mock_node_class.get_default_config.return_value = mock_config
+
+ # Create a mock mapping that includes NodeType.LLM
+ mock_mapping.__contains__.return_value = True
+ mock_mapping.__getitem__.return_value = {"latest": mock_node_class}
+
+ result = workflow_service.get_default_block_config(NodeType.LLM.value)
+
+ assert result == mock_config
+ mock_node_class.get_default_config.assert_called_once()
+
+ def test_get_default_block_config_invalid_node_type(self, workflow_service):
+ """Test get_default_block_config returns empty dict for invalid node type."""
+ with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping:
+ # Mock mapping to not contain the node type
+ mock_mapping.__contains__.return_value = False
+
+ # Use a valid NodeType but one that's not in the mapping
+ result = workflow_service.get_default_block_config(NodeType.LLM.value)
+
+ assert result == {}
+
+ # ==================== Workflow Conversion Tests ====================
+ # These tests verify converting basic apps to workflow apps
+
+ def test_convert_to_workflow_from_chat_app(self, workflow_service):
+ """
+ Test convert_to_workflow converts chat app to workflow.
+
+ Allows users to migrate from simple chat apps to advanced workflow apps.
+ The conversion creates equivalent workflow nodes from the chat configuration,
+ giving users more control and customization options.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.CHAT.value)
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ args = {
+ "name": "Converted Workflow",
+ "icon_type": "emoji",
+ "icon": "🤖",
+ "icon_background": "#FFEAD5",
+ }
+
+ with patch("services.workflow_service.WorkflowConverter") as MockConverter:
+ mock_converter = MockConverter.return_value
+ mock_new_app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ mock_converter.convert_to_workflow.return_value = mock_new_app
+
+ result = workflow_service.convert_to_workflow(app, account, args)
+
+ assert result == mock_new_app
+ mock_converter.convert_to_workflow.assert_called_once()
+
+ def test_convert_to_workflow_from_completion_app(self, workflow_service):
+ """
+ Test convert_to_workflow converts completion app to workflow.
+
+ Similar to chat conversion, but for completion-style apps.
+ Completion apps are simpler (single prompt-response), so the
+ conversion creates a basic workflow with fewer nodes.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION.value)
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ args = {"name": "Converted Workflow"}
+
+ with patch("services.workflow_service.WorkflowConverter") as MockConverter:
+ mock_converter = MockConverter.return_value
+ mock_new_app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ mock_converter.convert_to_workflow.return_value = mock_new_app
+
+ result = workflow_service.convert_to_workflow(app, account, args)
+
+ assert result == mock_new_app
+
+ def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service):
+ """
+ Test convert_to_workflow raises error for invalid app mode.
+
+ Only chat and completion apps can be converted to workflows.
+ Apps that are already workflows or have other modes cannot be converted.
+ """
+ app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.WORKFLOW.value)
+ account = TestWorkflowAssociatedDataFactory.create_account_mock()
+ args = {}
+
+ with pytest.raises(ValueError, match="not supported convert to workflow"):
+ workflow_service.convert_to_workflow(app, account, args)
From 7a7fea40d9eb5f15f18d8fd55f6ef8dc9166e1bf Mon Sep 17 00:00:00 2001
From: Gritty_dev <101377478+codomposer@users.noreply.github.com>
Date: Thu, 27 Nov 2025 01:39:33 -0500
Subject: [PATCH 029/431] feat: complete test script of dataset retrieval
(#28762)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../unit_tests/core/rag/retrieval/__init__.py | 0
.../rag/retrieval/test_dataset_retrieval.py | 1696 +++++++++++++++++
2 files changed, 1696 insertions(+)
create mode 100644 api/tests/unit_tests/core/rag/retrieval/__init__.py
create mode 100644 api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
diff --git a/api/tests/unit_tests/core/rag/retrieval/__init__.py b/api/tests/unit_tests/core/rag/retrieval/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
new file mode 100644
index 0000000000..0163e42992
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py
@@ -0,0 +1,1696 @@
+"""
+Unit tests for dataset retrieval functionality.
+
+This module provides comprehensive test coverage for the RetrievalService class,
+which is responsible for retrieving relevant documents from datasets using various
+search strategies.
+
+Core Retrieval Mechanisms Tested:
+==================================
+1. **Vector Search (Semantic Search)**
+ - Uses embedding vectors to find semantically similar documents
+ - Supports score thresholds and top-k limiting
+ - Can filter by document IDs and metadata
+
+2. **Keyword Search**
+ - Traditional text-based search using keyword matching
+ - Handles special characters and query escaping
+ - Supports document filtering
+
+3. **Full-Text Search**
+ - BM25-based full-text search for text matching
+ - Used in hybrid search scenarios
+
+4. **Hybrid Search**
+ - Combines vector and full-text search results
+ - Implements deduplication to avoid duplicate chunks
+ - Uses DataPostProcessor for score merging with configurable weights
+
+5. **Score Merging Algorithms**
+ - Deduplication based on doc_id
+ - Retains higher-scoring duplicates
+ - Supports weighted score combination
+
+6. **Metadata Filtering**
+ - Filters documents based on metadata conditions
+ - Supports document ID filtering
+
+Test Architecture:
+==================
+- **Fixtures**: Provide reusable mock objects (datasets, documents, Flask app)
+- **Mocking Strategy**: Mock at the method level (embedding_search, keyword_search, etc.)
+ rather than at the class level to properly simulate the ThreadPoolExecutor behavior
+- **Pattern**: All tests follow Arrange-Act-Assert (AAA) pattern
+- **Isolation**: Each test is independent and doesn't rely on external state
+
+Running Tests:
+==============
+ # Run all tests in this module
+ uv run --project api pytest \
+ api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py -v
+
+ # Run a specific test class
+ uv run --project api pytest \
+ api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::TestRetrievalService -v
+
+ # Run a specific test
+ uv run --project api pytest \
+ api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::\
+TestRetrievalService::test_vector_search_basic -v
+
+Notes:
+======
+- The RetrievalService uses ThreadPoolExecutor for concurrent search operations
+- Tests mock the individual search methods to avoid threading complexity
+- All mocked search methods modify the all_documents list in-place
+- Score thresholds and top-k limits are enforced by the search methods
+"""
+
+from unittest.mock import MagicMock, Mock, patch
+from uuid import uuid4
+
+import pytest
+
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from models.dataset import Dataset
+
+# ==================== Helper Functions ====================
+
+
+def create_mock_document(
+ content: str,
+ doc_id: str,
+ score: float = 0.8,
+ provider: str = "dify",
+ additional_metadata: dict | None = None,
+) -> Document:
+ """
+ Create a mock Document object for testing.
+
+ This helper function standardizes document creation across tests,
+ ensuring consistent structure and reducing code duplication.
+
+ Args:
+ content: The text content of the document
+ doc_id: Unique identifier for the document chunk
+ score: Relevance score (0.0 to 1.0)
+ provider: Document provider ("dify" or "external")
+ additional_metadata: Optional extra metadata fields
+
+ Returns:
+ Document: A properly structured Document object
+
+ Example:
+ >>> doc = create_mock_document("Python is great", "doc1", score=0.95)
+ >>> assert doc.metadata["score"] == 0.95
+ """
+ metadata = {
+ "doc_id": doc_id,
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": score,
+ }
+
+ # Merge additional metadata if provided
+ if additional_metadata:
+ metadata.update(additional_metadata)
+
+ return Document(
+ page_content=content,
+ metadata=metadata,
+ provider=provider,
+ )
+
+
+def create_side_effect_for_search(documents: list[Document]):
+ """
+ Create a side effect function for mocking search methods.
+
+ This helper creates a function that simulates how RetrievalService
+ search methods work - they modify the all_documents list in-place
+ rather than returning values directly.
+
+ Args:
+ documents: List of documents to add to all_documents
+
+ Returns:
+ Callable: A side effect function compatible with mock.side_effect
+
+ Example:
+ >>> mock_search.side_effect = create_side_effect_for_search([doc1, doc2])
+
+ Note:
+ The RetrievalService uses ThreadPoolExecutor which submits tasks that
+ modify a shared all_documents list. This pattern simulates that behavior.
+ """
+
+ def side_effect(flask_app, dataset_id, query, top_k, *args, all_documents, exceptions, **kwargs):
+ """
+ Side effect function that mimics search method behavior.
+
+ Args:
+ flask_app: Flask application context (unused in mock)
+ dataset_id: ID of the dataset being searched
+ query: Search query string
+ top_k: Maximum number of results
+ all_documents: Shared list to append results to
+ exceptions: Shared list to append errors to
+ **kwargs: Additional arguments (score_threshold, document_ids_filter, etc.)
+ """
+ all_documents.extend(documents)
+
+ return side_effect
+
+
+def create_side_effect_with_exception(error_message: str):
+ """
+ Create a side effect function that adds an exception to the exceptions list.
+
+ Used for testing error handling in the RetrievalService.
+
+ Args:
+ error_message: The error message to add to exceptions
+
+ Returns:
+ Callable: A side effect function that simulates an error
+
+ Example:
+ >>> mock_search.side_effect = create_side_effect_with_exception("Search failed")
+ """
+
+ def side_effect(flask_app, dataset_id, query, top_k, *args, all_documents, exceptions, **kwargs):
+ """Add error message to exceptions list."""
+ exceptions.append(error_message)
+
+ return side_effect
+
+
+class TestRetrievalService:
+ """
+ Comprehensive test suite for RetrievalService class.
+
+ This test class validates all retrieval methods and their interactions,
+ including edge cases, error handling, and integration scenarios.
+
+ Test Organization:
+ ==================
+ 1. Fixtures (lines ~190-240)
+ - mock_dataset: Standard dataset configuration
+ - sample_documents: Reusable test documents with varying scores
+ - mock_flask_app: Flask application context
+ - mock_thread_pool: Synchronous executor for deterministic testing
+
+ 2. Vector Search Tests (lines ~240-350)
+ - Basic functionality
+ - Document filtering
+ - Empty results
+ - Metadata filtering
+ - Score thresholds
+
+ 3. Keyword Search Tests (lines ~350-450)
+ - Basic keyword matching
+ - Special character handling
+ - Document filtering
+
+ 4. Hybrid Search Tests (lines ~450-640)
+ - Vector + full-text combination
+ - Deduplication logic
+ - Weighted score merging
+
+ 5. Full-Text Search Tests (lines ~640-680)
+ - BM25-based search
+
+ 6. Score Merging Tests (lines ~680-790)
+ - Deduplication algorithms
+ - Score comparison
+ - Provider-specific handling
+
+ 7. Error Handling Tests (lines ~790-920)
+ - Empty queries
+ - Non-existent datasets
+ - Exception propagation
+
+ 8. Additional Tests (lines ~920-1080)
+ - Query escaping
+ - Reranking integration
+ - Top-K limiting
+
+ Mocking Strategy:
+ =================
+ Tests mock at the method level (embedding_search, keyword_search, etc.)
+ rather than the underlying Vector/Keyword classes. This approach:
+ - Avoids complexity of mocking ThreadPoolExecutor behavior
+ - Provides clearer test intent
+ - Makes tests more maintainable
+ - Properly simulates the in-place list modification pattern
+
+ Common Patterns:
+ ================
+ 1. **Arrange**: Set up mocks with side_effect functions
+ 2. **Act**: Call RetrievalService.retrieve() with specific parameters
+ 3. **Assert**: Verify results, mock calls, and side effects
+
+ Example Test Structure:
+ ```python
+ def test_example(self, mock_get_dataset, mock_search, mock_dataset):
+ # Arrange: Set up test data and mocks
+ mock_get_dataset.return_value = mock_dataset
+ mock_search.side_effect = create_side_effect_for_search([doc1, doc2])
+
+ # Act: Execute the method under test
+ results = RetrievalService.retrieve(...)
+
+ # Assert: Verify expectations
+ assert len(results) == 2
+ mock_search.assert_called_once()
+ ```
+ """
+
+ @pytest.fixture
+ def mock_dataset(self) -> Dataset:
+ """
+ Create a mock Dataset object for testing.
+
+ Returns:
+ Dataset: Mock dataset with standard configuration
+ """
+ dataset = Mock(spec=Dataset)
+ dataset.id = str(uuid4())
+ dataset.tenant_id = str(uuid4())
+ dataset.name = "test_dataset"
+ dataset.indexing_technique = "high_quality"
+ dataset.embedding_model = "text-embedding-ada-002"
+ dataset.embedding_model_provider = "openai"
+ dataset.retrieval_model = {
+ "search_method": RetrievalMethod.SEMANTIC_SEARCH,
+ "reranking_enable": False,
+ "top_k": 4,
+ "score_threshold_enabled": False,
+ }
+ return dataset
+
+ @pytest.fixture
+ def sample_documents(self) -> list[Document]:
+ """
+ Create sample documents for testing retrieval results.
+
+ Returns:
+ list[Document]: List of mock documents with varying scores
+ """
+ return [
+ Document(
+ page_content="Python is a high-level programming language.",
+ metadata={
+ "doc_id": "doc1",
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": 0.95,
+ },
+ provider="dify",
+ ),
+ Document(
+ page_content="JavaScript is widely used for web development.",
+ metadata={
+ "doc_id": "doc2",
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": 0.85,
+ },
+ provider="dify",
+ ),
+ Document(
+ page_content="Machine learning is a subset of artificial intelligence.",
+ metadata={
+ "doc_id": "doc3",
+ "document_id": str(uuid4()),
+ "dataset_id": str(uuid4()),
+ "score": 0.75,
+ },
+ provider="dify",
+ ),
+ ]
+
+ @pytest.fixture
+ def mock_flask_app(self):
+ """
+ Create a mock Flask application context.
+
+ Returns:
+ Mock: Flask app mock with app_context
+ """
+ app = MagicMock()
+ app.app_context.return_value.__enter__ = Mock()
+ app.app_context.return_value.__exit__ = Mock()
+ return app
+
+ @pytest.fixture(autouse=True)
+ def mock_thread_pool(self):
+ """
+ Mock ThreadPoolExecutor to run tasks synchronously in tests.
+
+ The RetrievalService uses ThreadPoolExecutor to run search operations
+ concurrently (embedding_search, keyword_search, full_text_index_search).
+ In tests, we want synchronous execution for:
+ - Deterministic behavior
+ - Easier debugging
+ - Avoiding race conditions
+ - Simpler assertions
+
+ How it works:
+ -------------
+ 1. Intercepts ThreadPoolExecutor creation
+ 2. Replaces submit() to execute functions immediately (synchronously)
+ 3. Functions modify shared all_documents list in-place
+ 4. Mocks concurrent.futures.wait() since tasks are already done
+
+ Why this approach:
+ ------------------
+ - RetrievalService.retrieve() creates a ThreadPoolExecutor context
+ - It submits search tasks that modify all_documents list
+ - concurrent.futures.wait() waits for all tasks to complete
+ - By executing synchronously, we avoid threading complexity in tests
+
+ Returns:
+ Mock: Mocked ThreadPoolExecutor that executes tasks synchronously
+ """
+ with patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor:
+ # Store futures to track submitted tasks (for debugging if needed)
+ futures_list = []
+
+ def sync_submit(fn, *args, **kwargs):
+ """
+ Synchronous replacement for ThreadPoolExecutor.submit().
+
+ Instead of scheduling the function for async execution,
+ we execute it immediately in the current thread.
+
+ Args:
+ fn: The function to execute (e.g., embedding_search)
+ *args, **kwargs: Arguments to pass to the function
+
+ Returns:
+ Mock: A mock Future object
+ """
+ future = Mock()
+ try:
+ # Execute immediately - this modifies all_documents in place
+ # The function signature is: fn(flask_app, dataset_id, query,
+ # top_k, all_documents, exceptions, ...)
+ fn(*args, **kwargs)
+ future.result.return_value = None
+ future.exception.return_value = None
+ except Exception as e:
+ # If function raises, store exception in future
+ future.result.return_value = None
+ future.exception.return_value = e
+
+ futures_list.append(future)
+ return future
+
+ # Set up the mock executor instance
+ mock_executor_instance = Mock()
+ mock_executor_instance.submit = sync_submit
+
+ # Configure context manager behavior (__enter__ and __exit__)
+ mock_executor.return_value.__enter__.return_value = mock_executor_instance
+ mock_executor.return_value.__exit__.return_value = None
+
+ # Mock concurrent.futures.wait to do nothing since tasks are already done
+ # In real code, this waits for all futures to complete
+ # In tests, futures complete immediately, so wait is a no-op
+ with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
+ yield mock_executor
+
+ # ==================== Vector Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
+ """
+ Test basic vector/semantic search functionality.
+
+ This test validates the core vector search flow:
+ 1. Dataset is retrieved from database
+ 2. embedding_search is called via ThreadPoolExecutor
+ 3. Documents are added to shared all_documents list
+ 4. Results are returned to caller
+
+ Verifies:
+ - Vector search is called with correct parameters
+ - Results are returned in expected format
+ - Score threshold is applied correctly
+ - Documents maintain their metadata and scores
+ """
+ # ==================== ARRANGE ====================
+ # Set up the mock dataset that will be "retrieved" from database
+ mock_get_dataset.return_value = mock_dataset
+
+ # Create a side effect function that simulates embedding_search behavior
+ # In the real implementation, embedding_search:
+ # 1. Gets the dataset
+ # 2. Creates a Vector instance
+ # 3. Calls search_by_vector with embeddings
+ # 4. Extends all_documents with results
+ def side_effect_embedding_search(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ """Simulate embedding_search adding documents to the shared list."""
+ all_documents.extend(sample_documents)
+
+ mock_embedding_search.side_effect = side_effect_embedding_search
+
+ # Define test parameters
+ query = "What is Python?" # Natural language query
+ top_k = 3 # Maximum number of results to return
+ score_threshold = 0.7 # Minimum relevance score (0.0 to 1.0)
+
+ # ==================== ACT ====================
+ # Call the retrieve method with SEMANTIC_SEARCH strategy
+ # This will:
+ # 1. Check if query is empty (early return if so)
+ # 2. Get the dataset using _get_dataset
+ # 3. Create ThreadPoolExecutor
+ # 4. Submit embedding_search task
+ # 5. Wait for completion
+ # 6. Return all_documents list
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query=query,
+ top_k=top_k,
+ score_threshold=score_threshold,
+ )
+
+ # ==================== ASSERT ====================
+ # Verify we got the expected number of documents
+ assert len(results) == 3, "Should return 3 documents from sample_documents"
+
+ # Verify all results are Document objects (type safety)
+ assert all(isinstance(doc, Document) for doc in results), "All results should be Document instances"
+
+ # Verify documents maintain their scores (highest score first in sample_documents)
+ assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
+
+ # Verify embedding_search was called exactly once
+ # This confirms the search method was invoked by ThreadPoolExecutor
+ mock_embedding_search.assert_called_once()
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_with_document_filter(
+ self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
+ ):
+ """
+ Test vector search with document ID filtering.
+
+ Verifies:
+ - Document ID filter is passed correctly to vector search
+ - Only specified documents are searched
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ filtered_docs = [sample_documents[0]]
+
+ def side_effect_embedding_search(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(filtered_docs)
+
+ mock_embedding_search.side_effect = side_effect_embedding_search
+ document_ids_filter = [sample_documents[0].metadata["document_id"]]
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=5,
+ document_ids_filter=document_ids_filter,
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata["doc_id"] == "doc1"
+ # Verify document_ids_filter was passed
+ call_kwargs = mock_embedding_search.call_args.kwargs
+ assert call_kwargs["document_ids_filter"] == document_ids_filter
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test vector search when no results match the query.
+
+ Verifies:
+ - Empty list is returned when no documents match
+ - No errors are raised
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ # embedding_search doesn't add anything to all_documents
+ mock_embedding_search.side_effect = lambda *args, **kwargs: None
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="nonexistent query",
+ top_k=5,
+ )
+
+ # Assert
+ assert results == []
+
+ # ==================== Keyword Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
+ """
+ Test basic keyword search functionality.
+
+ Verifies:
+ - Keyword search is invoked correctly
+ - Query is escaped properly for search
+ - Results are returned in expected format
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ def side_effect_keyword_search(
+ flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
+ ):
+ all_documents.extend(sample_documents)
+
+ mock_keyword_search.side_effect = side_effect_keyword_search
+
+ query = "Python programming"
+ top_k = 3
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
+ dataset_id=mock_dataset.id,
+ query=query,
+ top_k=top_k,
+ )
+
+ # Assert
+ assert len(results) == 3
+ assert all(isinstance(doc, Document) for doc in results)
+ mock_keyword_search.assert_called_once()
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_keyword_search_with_special_characters(self, mock_get_dataset, mock_keyword_search, mock_dataset):
+ """
+ Test keyword search with special characters in query.
+
+ Verifies:
+ - Special characters are escaped correctly
+ - Search handles quotes and other special chars
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ mock_keyword_search.side_effect = lambda *args, **kwargs: None
+
+ query = 'Python "programming" language'
+
+ # Act
+ RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
+ dataset_id=mock_dataset.id,
+ query=query,
+ top_k=5,
+ )
+
+ # Assert
+ # Verify that keyword_search was called
+ assert mock_keyword_search.called
+ # The query escaping happens inside keyword_search method
+ call_args = mock_keyword_search.call_args
+ assert call_args is not None
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_keyword_search_with_document_filter(
+ self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents
+ ):
+ """
+ Test keyword search with document ID filtering.
+
+ Verifies:
+ - Document filter is applied to keyword search
+ - Only filtered documents are returned
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+ filtered_docs = [sample_documents[1]]
+
+ def side_effect_keyword_search(
+ flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
+ ):
+ all_documents.extend(filtered_docs)
+
+ mock_keyword_search.side_effect = side_effect_keyword_search
+ document_ids_filter = [sample_documents[1].metadata["document_id"]]
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="JavaScript",
+ top_k=5,
+ document_ids_filter=document_ids_filter,
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata["doc_id"] == "doc2"
+
+ # ==================== Hybrid Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.DataPostProcessor")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_hybrid_search_basic(
+ self,
+ mock_get_dataset,
+ mock_embedding_search,
+ mock_fulltext_search,
+ mock_data_processor_class,
+ mock_dataset,
+ sample_documents,
+ ):
+ """
+ Test basic hybrid search combining vector and full-text search.
+
+ Verifies:
+ - Both vector and full-text search are executed
+ - Results are merged and deduplicated
+ - DataPostProcessor is invoked for score merging
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Vector search returns first 2 docs
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[:2])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ # Full-text search returns last 2 docs (with overlap)
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[1:])
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ # Mock DataPostProcessor
+ mock_processor_instance = Mock()
+ mock_processor_instance.invoke.return_value = sample_documents
+ mock_data_processor_class.return_value = mock_processor_instance
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.HYBRID_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="Python programming",
+ top_k=3,
+ score_threshold=0.5,
+ )
+
+ # Assert
+ assert len(results) == 3
+ mock_embedding_search.assert_called_once()
+ mock_fulltext_search.assert_called_once()
+ mock_processor_instance.invoke.assert_called_once()
+
+ @patch("core.rag.datasource.retrieval_service.DataPostProcessor")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_hybrid_search_deduplication(
+ self, mock_get_dataset, mock_embedding_search, mock_fulltext_search, mock_data_processor_class, mock_dataset
+ ):
+ """
+ Test that hybrid search properly deduplicates documents.
+
+ Hybrid search combines results from multiple search methods (vector + full-text).
+ This can lead to duplicate documents when the same chunk is found by both methods.
+
+ Scenario:
+ ---------
+ 1. Vector search finds document "duplicate_doc" with score 0.9
+ 2. Full-text search also finds "duplicate_doc" but with score 0.6
+ 3. Both searches find "unique_doc"
+ 4. Deduplication should keep only the higher-scoring version (0.9)
+
+ Why deduplication matters:
+ --------------------------
+ - Prevents showing the same content multiple times to users
+ - Ensures score consistency (keeps best match)
+ - Improves result quality and user experience
+ - Happens BEFORE reranking to avoid processing duplicates
+
+ Verifies:
+ - Duplicate documents (same doc_id) are removed
+ - Higher scoring duplicate is retained
+ - Deduplication happens before post-processing
+ - Final result count is correct
+ """
+ # ==================== ARRANGE ====================
+ mock_get_dataset.return_value = mock_dataset
+
+ # Create test documents with intentional duplication
+ # Same doc_id but different scores to test score comparison logic
+ doc1_high = Document(
+ page_content="Content 1",
+ metadata={
+ "doc_id": "duplicate_doc", # Same doc_id as doc1_low
+ "score": 0.9, # Higher score - should be kept
+ "document_id": str(uuid4()),
+ },
+ provider="dify",
+ )
+ doc1_low = Document(
+ page_content="Content 1",
+ metadata={
+ "doc_id": "duplicate_doc", # Same doc_id as doc1_high
+ "score": 0.6, # Lower score - should be discarded
+ "document_id": str(uuid4()),
+ },
+ provider="dify",
+ )
+ doc2 = Document(
+ page_content="Content 2",
+ metadata={
+ "doc_id": "unique_doc", # Unique doc_id
+ "score": 0.8,
+ "document_id": str(uuid4()),
+ },
+ provider="dify",
+ )
+
+ # Simulate vector search returning high-score duplicate + unique doc
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ """Vector search finds 2 documents including high-score duplicate."""
+ all_documents.extend([doc1_high, doc2])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ # Simulate full-text search returning low-score duplicate
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ """Full-text search finds the same document but with lower score."""
+ all_documents.extend([doc1_low])
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ # Mock DataPostProcessor to return deduplicated results
+ # In real implementation, _deduplicate_documents is called before this
+ mock_processor_instance = Mock()
+ mock_processor_instance.invoke.return_value = [doc1_high, doc2]
+ mock_data_processor_class.return_value = mock_processor_instance
+
+ # ==================== ACT ====================
+ # Execute hybrid search which should:
+ # 1. Run both embedding_search and full_text_index_search
+ # 2. Collect all results in all_documents (3 docs: 2 unique + 1 duplicate)
+ # 3. Call _deduplicate_documents to remove duplicate (keeps higher score)
+ # 4. Pass deduplicated results to DataPostProcessor
+ # 5. Return final results
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.HYBRID_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test",
+ top_k=5,
+ )
+
+ # ==================== ASSERT ====================
+ # Verify deduplication worked correctly
+ assert len(results) == 2, "Should have 2 unique documents after deduplication (not 3)"
+
+ # Verify the correct documents are present
+ doc_ids = [doc.metadata["doc_id"] for doc in results]
+ assert "duplicate_doc" in doc_ids, "Duplicate doc should be present (higher score version)"
+ assert "unique_doc" in doc_ids, "Unique doc should be present"
+
+ # Implicitly verifies that doc1_low (score 0.6) was discarded
+ # in favor of doc1_high (score 0.9)
+
+ @patch("core.rag.datasource.retrieval_service.DataPostProcessor")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_hybrid_search_with_weights(
+ self,
+ mock_get_dataset,
+ mock_embedding_search,
+ mock_fulltext_search,
+ mock_data_processor_class,
+ mock_dataset,
+ sample_documents,
+ ):
+ """
+ Test hybrid search with custom weights for score merging.
+
+ Verifies:
+ - Weights are passed to DataPostProcessor
+ - Score merging respects weight configuration
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[:2])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents[1:])
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ mock_processor_instance = Mock()
+ mock_processor_instance.invoke.return_value = sample_documents
+ mock_data_processor_class.return_value = mock_processor_instance
+
+ weights = {
+ "vector_setting": {
+ "vector_weight": 0.7,
+ "embedding_provider_name": "openai",
+ "embedding_model_name": "text-embedding-ada-002",
+ },
+ "keyword_setting": {"keyword_weight": 0.3},
+ }
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.HYBRID_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=3,
+ weights=weights,
+ reranking_mode="weighted_score",
+ )
+
+ # Assert
+ assert len(results) == 3
+ # Verify DataPostProcessor was created with weights
+ mock_data_processor_class.assert_called_once()
+ # Check that weights were passed (may be in args or kwargs)
+ call_args = mock_data_processor_class.call_args
+ if call_args.kwargs:
+ assert call_args.kwargs.get("weights") == weights
+ else:
+ # Weights might be in positional args (position 3)
+ assert len(call_args.args) >= 4
+
+ # ==================== Full-Text Search Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_fulltext_search_basic(self, mock_get_dataset, mock_fulltext_search, mock_dataset, sample_documents):
+ """
+ Test basic full-text search functionality.
+
+ Verifies:
+ - Full-text search is invoked correctly
+ - Results are returned in expected format
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ def side_effect_fulltext(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.extend(sample_documents)
+
+ mock_fulltext_search.side_effect = side_effect_fulltext
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="programming language",
+ top_k=3,
+ )
+
+ # Assert
+ assert len(results) == 3
+ mock_fulltext_search.assert_called_once()
+
+ # ==================== Score Merging Tests ====================
+
+ def test_deduplicate_documents_basic(self):
+ """
+ Test basic document deduplication logic.
+
+ Verifies:
+ - Documents with same doc_id are deduplicated
+ - First occurrence is kept by default
+ """
+ # Arrange
+ doc1 = Document(
+ page_content="Content 1",
+ metadata={"doc_id": "doc1", "score": 0.8},
+ provider="dify",
+ )
+ doc2 = Document(
+ page_content="Content 2",
+ metadata={"doc_id": "doc2", "score": 0.7},
+ provider="dify",
+ )
+ doc1_duplicate = Document(
+ page_content="Content 1 duplicate",
+ metadata={"doc_id": "doc1", "score": 0.6},
+ provider="dify",
+ )
+
+ documents = [doc1, doc2, doc1_duplicate]
+
+ # Act
+ result = RetrievalService._deduplicate_documents(documents)
+
+ # Assert
+ assert len(result) == 2
+ doc_ids = [doc.metadata["doc_id"] for doc in result]
+ assert doc_ids == ["doc1", "doc2"]
+
+ def test_deduplicate_documents_keeps_higher_score(self):
+ """
+ Test that deduplication keeps document with higher score.
+
+ Verifies:
+ - When duplicates exist, higher scoring version is retained
+ - Score comparison works correctly
+ """
+ # Arrange
+ doc_low = Document(
+ page_content="Content",
+ metadata={"doc_id": "doc1", "score": 0.5},
+ provider="dify",
+ )
+ doc_high = Document(
+ page_content="Content",
+ metadata={"doc_id": "doc1", "score": 0.9},
+ provider="dify",
+ )
+
+ # Low score first
+ documents = [doc_low, doc_high]
+
+ # Act
+ result = RetrievalService._deduplicate_documents(documents)
+
+ # Assert
+ assert len(result) == 1
+ assert result[0].metadata["score"] == 0.9
+
+ def test_deduplicate_documents_empty_list(self):
+ """
+ Test deduplication with empty document list.
+
+ Verifies:
+ - Empty list returns empty list
+ - No errors are raised
+ """
+ # Act
+ result = RetrievalService._deduplicate_documents([])
+
+ # Assert
+ assert result == []
+
+ def test_deduplicate_documents_non_dify_provider(self):
+ """
+ Test deduplication with non-dify provider documents.
+
+ Verifies:
+ - External provider documents use content-based deduplication
+ - Different providers are handled correctly
+ """
+ # Arrange
+ doc1 = Document(
+ page_content="External content",
+ metadata={"score": 0.8},
+ provider="external",
+ )
+ doc2 = Document(
+ page_content="External content",
+ metadata={"score": 0.7},
+ provider="external",
+ )
+
+ documents = [doc1, doc2]
+
+ # Act
+ result = RetrievalService._deduplicate_documents(documents)
+
+ # Assert
+ # External documents without doc_id should use content-based dedup
+ assert len(result) >= 1
+
+ # ==================== Metadata Filtering Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_with_metadata_filter(
+ self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
+ ):
+ """
+ Test vector search with metadata-based document filtering.
+
+ Verifies:
+ - Metadata filters are applied correctly
+ - Only documents matching metadata criteria are returned
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Add metadata to documents
+ filtered_doc = sample_documents[0]
+ filtered_doc.metadata["category"] = "programming"
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.append(filtered_doc)
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="Python",
+ top_k=5,
+ document_ids_filter=[filtered_doc.metadata["document_id"]],
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata.get("category") == "programming"
+
+ # ==================== Error Handling Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_with_empty_query(self, mock_get_dataset, mock_dataset):
+ """
+ Test retrieval with empty query string.
+
+ Verifies:
+ - Empty query returns empty results
+ - No search operations are performed
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="",
+ top_k=5,
+ )
+
+ # Assert
+ assert results == []
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_with_nonexistent_dataset(self, mock_get_dataset):
+ """
+ Test retrieval with non-existent dataset ID.
+
+ Verifies:
+ - Non-existent dataset returns empty results
+ - No errors are raised
+ """
+ # Arrange
+ mock_get_dataset.return_value = None
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id="nonexistent_id",
+ query="test query",
+ top_k=5,
+ )
+
+ # Assert
+ assert results == []
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test that exceptions during retrieval are properly handled.
+
+ Verifies:
+ - Exceptions are caught and added to exceptions list
+ - ValueError is raised with exception messages
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Make embedding_search add an exception to the exceptions list
+ def side_effect_with_exception(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ exceptions.append("Search failed")
+
+ mock_embedding_search.side_effect = side_effect_with_exception
+
+ # Act & Assert
+ with pytest.raises(ValueError) as exc_info:
+ RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=5,
+ )
+
+ assert "Search failed" in str(exc_info.value)
+
+ # ==================== Score Threshold Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test vector search with score threshold filtering.
+
+ Verifies:
+ - Score threshold is passed to search method
+ - Documents below threshold are filtered out
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Only return documents above threshold
+ high_score_doc = Document(
+ page_content="High relevance content",
+ metadata={"doc_id": "doc1", "score": 0.85},
+ provider="dify",
+ )
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ all_documents.append(high_score_doc)
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ score_threshold = 0.8
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=5,
+ score_threshold=score_threshold,
+ )
+
+ # Assert
+ assert len(results) == 1
+ assert results[0].metadata["score"] >= score_threshold
+
+ # ==================== Top-K Limiting Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
+ """
+ Test that retrieval respects top_k parameter.
+
+ Verifies:
+ - Only top_k documents are returned
+ - Limit is applied correctly
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Create more documents than top_k
+ many_docs = [
+ Document(
+ page_content=f"Content {i}",
+ metadata={"doc_id": f"doc{i}", "score": 0.9 - i * 0.1},
+ provider="dify",
+ )
+ for i in range(10)
+ ]
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ # Return only top_k documents
+ all_documents.extend(many_docs[:top_k])
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ top_k = 3
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=top_k,
+ )
+
+ # Assert
+ # Verify top_k was passed to embedding_search
+ assert mock_embedding_search.called
+ call_kwargs = mock_embedding_search.call_args.kwargs
+ assert call_kwargs["top_k"] == top_k
+ # Verify we got the right number of results
+ assert len(results) == top_k
+
+ # ==================== Query Escaping Tests ====================
+
+ def test_escape_query_for_search(self):
+ """
+ Test query escaping for special characters.
+
+ Verifies:
+ - Double quotes are properly escaped
+ - Other characters remain unchanged
+ """
+ # Test cases with expected outputs
+ test_cases = [
+ ("simple query", "simple query"),
+ ('query with "quotes"', 'query with \\"quotes\\"'),
+ ('"quoted phrase"', '\\"quoted phrase\\"'),
+ ("no special chars", "no special chars"),
+ ]
+
+ for input_query, expected_output in test_cases:
+ result = RetrievalService.escape_query_for_search(input_query)
+ assert result == expected_output
+
+ # ==================== Reranking Tests ====================
+
+ @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+ @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
+ def test_semantic_search_with_reranking(
+ self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
+ ):
+ """
+ Test semantic search with reranking model.
+
+ Verifies:
+ - Reranking is applied when configured
+ - DataPostProcessor is invoked with correct parameters
+ """
+ # Arrange
+ mock_get_dataset.return_value = mock_dataset
+
+ # Simulate reranking changing order
+ reranked_docs = list(reversed(sample_documents))
+
+ def side_effect_embedding(
+ flask_app,
+ dataset_id,
+ query,
+ top_k,
+ score_threshold,
+ reranking_model,
+ all_documents,
+ retrieval_method,
+ exceptions,
+ document_ids_filter=None,
+ ):
+ # embedding_search handles reranking internally
+ all_documents.extend(reranked_docs)
+
+ mock_embedding_search.side_effect = side_effect_embedding
+
+ reranking_model = {
+ "reranking_provider_name": "cohere",
+ "reranking_model_name": "rerank-english-v2.0",
+ }
+
+ # Act
+ results = RetrievalService.retrieve(
+ retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
+ dataset_id=mock_dataset.id,
+ query="test query",
+ top_k=3,
+ reranking_model=reranking_model,
+ )
+
+ # Assert
+ # For semantic search with reranking, reranking_model should be passed
+ assert len(results) == 3
+ call_kwargs = mock_embedding_search.call_args.kwargs
+ assert call_kwargs["reranking_model"] == reranking_model
+
+
+class TestRetrievalMethods:
+ """
+ Test suite for RetrievalMethod enum and utility methods.
+
+ The RetrievalMethod enum defines the available search strategies:
+
+ 1. **SEMANTIC_SEARCH**: Vector-based similarity search using embeddings
+ - Best for: Natural language queries, conceptual similarity
+ - Uses: Embedding models (e.g., text-embedding-ada-002)
+ - Example: "What is machine learning?" matches "AI and ML concepts"
+
+ 2. **FULL_TEXT_SEARCH**: BM25-based text matching
+ - Best for: Exact phrase matching, keyword presence
+ - Uses: BM25 algorithm with sparse vectors
+ - Example: "Python programming" matches documents with those exact terms
+
+ 3. **HYBRID_SEARCH**: Combination of semantic + full-text
+ - Best for: Comprehensive search with both conceptual and exact matching
+ - Uses: Both embedding vectors and BM25, with score merging
+ - Example: Finds both semantically similar and keyword-matching documents
+
+ 4. **KEYWORD_SEARCH**: Traditional keyword-based search (economy mode)
+ - Best for: Simple, fast searches without embeddings
+ - Uses: Jieba tokenization and keyword matching
+ - Example: Basic text search without vector database
+
+ Utility Methods:
+ ================
+ - is_support_semantic_search(): Check if method uses embeddings
+ - is_support_fulltext_search(): Check if method uses BM25
+
+ These utilities help determine which search operations to execute
+ in the RetrievalService.retrieve() method.
+ """
+
+ def test_retrieval_method_values(self):
+ """
+ Test that all retrieval method constants are defined correctly.
+
+ This ensures the enum values match the expected string constants
+ used throughout the codebase for configuration and API calls.
+
+ Verifies:
+ - All expected retrieval methods exist
+ - Values are correct strings (not accidentally changed)
+ - String values match database/config expectations
+ """
+ assert RetrievalMethod.SEMANTIC_SEARCH == "semantic_search"
+ assert RetrievalMethod.FULL_TEXT_SEARCH == "full_text_search"
+ assert RetrievalMethod.HYBRID_SEARCH == "hybrid_search"
+ assert RetrievalMethod.KEYWORD_SEARCH == "keyword_search"
+
+ def test_is_support_semantic_search(self):
+ """
+ Test semantic search support detection.
+
+ Verifies:
+ - Semantic search method is detected
+ - Hybrid search method is detected (includes semantic)
+ - Other methods are not detected
+ """
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.SEMANTIC_SEARCH) is True
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.HYBRID_SEARCH) is True
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.FULL_TEXT_SEARCH) is False
+ assert RetrievalMethod.is_support_semantic_search(RetrievalMethod.KEYWORD_SEARCH) is False
+
+ def test_is_support_fulltext_search(self):
+ """
+ Test full-text search support detection.
+
+ Verifies:
+ - Full-text search method is detected
+ - Hybrid search method is detected (includes full-text)
+ - Other methods are not detected
+ """
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.FULL_TEXT_SEARCH) is True
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.HYBRID_SEARCH) is True
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.SEMANTIC_SEARCH) is False
+ assert RetrievalMethod.is_support_fulltext_search(RetrievalMethod.KEYWORD_SEARCH) is False
+
+
+class TestDocumentModel:
+ """
+ Test suite for Document model used in retrieval.
+
+ The Document class is the core data structure for representing text chunks
+ in the retrieval system. It's based on Pydantic BaseModel for validation.
+
+ Document Structure:
+ ===================
+ - **page_content** (str): The actual text content of the document chunk
+ - **metadata** (dict): Additional information about the document
+ - doc_id: Unique identifier for the chunk
+ - document_id: Parent document ID
+ - dataset_id: Dataset this document belongs to
+ - score: Relevance score from search (0.0 to 1.0)
+ - Custom fields: category, tags, timestamps, etc.
+ - **provider** (str): Source of the document ("dify" or "external")
+ - **vector** (list[float] | None): Embedding vector for semantic search
+ - **children** (list[ChildDocument] | None): Sub-chunks for hierarchical docs
+
+ Document Lifecycle:
+ ===================
+ 1. **Creation**: Documents are created when text is indexed
+ - Content is chunked into manageable pieces
+ - Embeddings are generated for semantic search
+ - Metadata is attached for filtering and tracking
+
+ 2. **Storage**: Documents are stored in vector databases
+ - Vector field stores embeddings
+ - Metadata enables filtering
+ - Provider tracks source (internal vs external)
+
+ 3. **Retrieval**: Documents are returned from search operations
+ - Scores are added during search
+ - Multiple documents may be combined (hybrid search)
+ - Deduplication uses doc_id
+
+ 4. **Post-processing**: Documents may be reranked or filtered
+ - Scores can be recalculated
+ - Content may be truncated or formatted
+ - Metadata is used for display
+
+ Why Test the Document Model:
+ ============================
+ - Ensures data structure integrity
+ - Validates Pydantic model behavior
+ - Confirms default values work correctly
+ - Tests equality comparison for deduplication
+ - Verifies metadata handling
+
+ Related Classes:
+ ================
+ - ChildDocument: For hierarchical document structures
+ - RetrievalSegments: Combines Document with database segment info
+ """
+
+ def test_document_creation_basic(self):
+ """
+ Test basic Document object creation.
+
+ Tests the minimal required fields and default values.
+ Only page_content is required; all other fields have defaults.
+
+ Verifies:
+ - Document can be created with minimal fields
+ - Default values are set correctly
+ - Pydantic validation works
+ - No exceptions are raised
+ """
+ doc = Document(page_content="Test content")
+
+ assert doc.page_content == "Test content"
+ assert doc.metadata == {} # Empty dict by default
+ assert doc.provider == "dify" # Default provider
+ assert doc.vector is None # No embedding by default
+ assert doc.children is None # No child documents by default
+
+ def test_document_creation_with_metadata(self):
+ """
+ Test Document creation with metadata.
+
+ Verifies:
+ - Metadata is stored correctly
+ - Metadata can contain various types
+ """
+ metadata = {
+ "doc_id": "test_doc",
+ "score": 0.95,
+ "dataset_id": str(uuid4()),
+ "category": "test",
+ }
+ doc = Document(page_content="Test content", metadata=metadata)
+
+ assert doc.metadata == metadata
+ assert doc.metadata["score"] == 0.95
+
+ def test_document_creation_with_vector(self):
+ """
+ Test Document creation with embedding vector.
+
+ Verifies:
+ - Vector embeddings can be stored
+ - Vector is optional
+ """
+ vector = [0.1, 0.2, 0.3, 0.4, 0.5]
+ doc = Document(page_content="Test content", vector=vector)
+
+ assert doc.vector == vector
+ assert len(doc.vector) == 5
+
+ def test_document_with_external_provider(self):
+ """
+ Test Document with external provider.
+
+ Verifies:
+ - Provider can be set to external
+ - External documents are handled correctly
+ """
+ doc = Document(page_content="External content", provider="external")
+
+ assert doc.provider == "external"
+
+ def test_document_equality(self):
+ """
+ Test Document equality comparison.
+
+ Verifies:
+ - Documents with same content are considered equal
+ - Metadata affects equality
+ """
+ doc1 = Document(page_content="Content", metadata={"id": "1"})
+ doc2 = Document(page_content="Content", metadata={"id": "1"})
+ doc3 = Document(page_content="Different", metadata={"id": "1"})
+
+ assert doc1 == doc2
+ assert doc1 != doc3
From 58f448a926174fa90a2d971432dacb218a990c11 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Thu, 27 Nov 2025 14:40:06 +0800
Subject: [PATCH 030/431] chore: remove outdated model config doc (#28765)
---
.../en_US/customizable_model_scale_out.md | 308 --------
.../docs/en_US/images/index/image-1.png | Bin 235102 -> 0 bytes
.../docs/en_US/images/index/image-2.png | Bin 210087 -> 0 bytes
.../images/index/image-20231210143654461.png | Bin 379070 -> 0 bytes
.../images/index/image-20231210144229650.png | Bin 115258 -> 0 bytes
.../images/index/image-20231210144814617.png | Bin 111420 -> 0 bytes
.../images/index/image-20231210151548521.png | Bin 71354 -> 0 bytes
.../images/index/image-20231210151628992.png | Bin 76990 -> 0 bytes
.../images/index/image-20231210165243632.png | Bin 554357 -> 0 bytes
.../docs/en_US/images/index/image-3.png | Bin 44778 -> 0 bytes
.../docs/en_US/images/index/image.png | Bin 267979 -> 0 bytes
.../model_runtime/docs/en_US/interfaces.md | 701 -----------------
.../docs/en_US/predefined_model_scale_out.md | 176 -----
.../docs/en_US/provider_scale_out.md | 266 -------
api/core/model_runtime/docs/en_US/schema.md | 208 -----
.../zh_Hans/customizable_model_scale_out.md | 304 -------
.../docs/zh_Hans/images/index/image-1.png | Bin 235102 -> 0 bytes
.../docs/zh_Hans/images/index/image-2.png | Bin 210087 -> 0 bytes
.../images/index/image-20231210143654461.png | Bin 394062 -> 0 bytes
.../images/index/image-20231210144229650.png | Bin 115258 -> 0 bytes
.../images/index/image-20231210144814617.png | Bin 111420 -> 0 bytes
.../images/index/image-20231210151548521.png | Bin 71354 -> 0 bytes
.../images/index/image-20231210151628992.png | Bin 76990 -> 0 bytes
.../images/index/image-20231210165243632.png | Bin 554357 -> 0 bytes
.../docs/zh_Hans/images/index/image-3.png | Bin 44778 -> 0 bytes
.../docs/zh_Hans/images/index/image.png | Bin 267979 -> 0 bytes
.../model_runtime/docs/zh_Hans/interfaces.md | 744 ------------------
.../zh_Hans/predefined_model_scale_out.md | 172 ----
.../docs/zh_Hans/provider_scale_out.md | 192 -----
api/core/model_runtime/docs/zh_Hans/schema.md | 209 -----
30 files changed, 3280 deletions(-)
delete mode 100644 api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-1.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-2.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210143654461.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210144229650.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210144814617.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210151548521.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210151628992.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-20231210165243632.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image-3.png
delete mode 100644 api/core/model_runtime/docs/en_US/images/index/image.png
delete mode 100644 api/core/model_runtime/docs/en_US/interfaces.md
delete mode 100644 api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/en_US/provider_scale_out.md
delete mode 100644 api/core/model_runtime/docs/en_US/schema.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-1.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-2.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210143654461.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144229650.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144814617.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210151548521.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210151628992.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-20231210165243632.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image-3.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/images/index/image.png
delete mode 100644 api/core/model_runtime/docs/zh_Hans/interfaces.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/provider_scale_out.md
delete mode 100644 api/core/model_runtime/docs/zh_Hans/schema.md
diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
deleted file mode 100644
index 245aa4699c..0000000000
--- a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
+++ /dev/null
@@ -1,308 +0,0 @@
-## Custom Integration of Pre-defined Models
-
-### Introduction
-
-After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration.
-
-It is important to note that for custom models, each model connection requires a complete vendor credential.
-
-Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
-
-
-
-As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user.
-
-### Writing the Vendor YAML
-
-First, we need to identify the types of models supported by the vendor we are integrating.
-
-Currently supported model types are as follows:
-
-- `llm` Text Generation Models
-
-- `text_embedding` Text Embedding Models
-
-- `rerank` Rerank Models
-
-- `speech2text` Speech-to-Text
-
-- `tts` Text-to-Speech
-
-- `moderation` Moderation
-
-Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml.
-
-```yaml
-provider: xinference #Define the vendor identifier
-label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
- en_US: Xorbits Inference
-icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
- en_US: icon_s_en.svg
-icon_large: # Large icon
- en_US: icon_l_en.svg
-help: # Help information
- title:
- en_US: How to deploy Xinference
- zh_Hans: 如何部署 Xinference
- url:
- en_US: https://github.com/xorbitsai/inference
-supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
-- llm
-- text-embedding
-- rerank
-configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
-- customizable-model
-provider_credential_schema:
- credential_form_schemas:
-```
-
-Then, we need to determine what credentials are required to define a model in Xinference.
-
-- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
-
-```yaml
-provider_credential_schema:
- credential_form_schemas:
- - variable: model_type
- type: select
- label:
- en_US: Model type
- zh_Hans: 模型类型
- required: true
- options:
- - value: text-generation
- label:
- en_US: Language Model
- zh_Hans: 语言模型
- - value: embeddings
- label:
- en_US: Text Embedding
- - value: reranking
- label:
- en_US: Rerank
-```
-
-- Next, each model has its own model_name, so we need to define that here:
-
-```yaml
- - variable: model_name
- type: text-input
- label:
- en_US: Model name
- zh_Hans: 模型名称
- required: true
- placeholder:
- zh_Hans: 填写模型名称
- en_US: Input model name
-```
-
-- Specify the Xinference local deployment address:
-
-```yaml
- - variable: server_url
- label:
- zh_Hans: 服务器 URL
- en_US: Server url
- type: text-input
- required: true
- placeholder:
- zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
- en_US: Enter the url of your Xinference, for example https://example.com/xxx
-```
-
-- Each model has a unique model_uid, so we also need to define that here:
-
-```yaml
- - variable: model_uid
- label:
- zh_Hans: 模型 UID
- en_US: Model uid
- type: text-input
- required: true
- placeholder:
- zh_Hans: 在此输入您的 Model UID
- en_US: Enter the model uid
-```
-
-Now, we have completed the basic definition of the vendor.
-
-### Writing the Model Code
-
-Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
-
-In `llm.py`, create a Xinference LLM class, we name it `XinferenceAILargeLanguageModel` (this can be arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
-
-- LLM Invocation
-
-Implement the core method for LLM invocation, supporting both stream and synchronous responses.
-
-```python
-def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) \
- -> Union[LLMResult, Generator]:
- """
- Invoke large language model
-
- :param model: model name
- :param credentials: model credentials
- :param prompt_messages: prompt messages
- :param model_parameters: model parameters
- :param tools: tools for tool usage
- :param stop: stop words
- :param stream: is the response a stream
- :param user: unique user id
- :return: full response or stream response chunk generator result
- """
-```
-
-When implementing, ensure to use two functions to return data separately for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Here’s an example (note that the example uses simplified parameters; in real implementation, use the parameter list as defined above):
-
-```python
-def _invoke(self, stream: bool, **kwargs) \
- -> Union[LLMResult, Generator]:
- if stream:
- return self._handle_stream_response(**kwargs)
- return self._handle_sync_response(**kwargs)
-
-def _handle_stream_response(self, **kwargs) -> Generator:
- for chunk in response:
- yield chunk
-def _handle_sync_response(self, **kwargs) -> LLMResult:
- return LLMResult(**response)
-```
-
-- Pre-compute Input Tokens
-
-If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
-
-```python
-def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],tools: Optional[list[PromptMessageTool]] = None) -> int:
- """
- Get number of tokens for given prompt messages
-
- :param model: model name
- :param credentials: model credentials
- :param prompt_messages: prompt messages
- :param tools: tools for tool usage
- :return: token count
- """
-```
-
-Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`, This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
-
-- Model Credentials Validation
-
-Similar to vendor credentials validation, this method validates individual model credentials.
-
-```python
-def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return: None
- """
-```
-
-- Model Parameter Schema
-
-Unlike custom types, since the YAML file does not define which parameters a model supports, we need to dynamically generate the model parameter schema.
-
-For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
-
-However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
-
-```python
- def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
- """
- used to define customizable model schema
- """
- rules = [
- ParameterRule(
- name='temperature', type=ParameterType.FLOAT,
- use_template='temperature',
- label=I18nObject(
- zh_Hans='温度', en_US='Temperature'
- )
- ),
- ParameterRule(
- name='top_p', type=ParameterType.FLOAT,
- use_template='top_p',
- label=I18nObject(
- zh_Hans='Top P', en_US='Top P'
- )
- ),
- ParameterRule(
- name='max_tokens', type=ParameterType.INT,
- use_template='max_tokens',
- min=1,
- default=512,
- label=I18nObject(
- zh_Hans='最大生成长度', en_US='Max Tokens'
- )
- )
- ]
-
- # if model is A, add top_k to rules
- if model == 'A':
- rules.append(
- ParameterRule(
- name='top_k', type=ParameterType.INT,
- use_template='top_k',
- min=1,
- default=50,
- label=I18nObject(
- zh_Hans='Top K', en_US='Top K'
- )
- )
- )
-
- """
- some NOT IMPORTANT code here
- """
-
- entity = AIModelEntity(
- model=model,
- label=I18nObject(
- en_US=model
- ),
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_type=model_type,
- model_properties={
- ModelPropertyKey.MODE: ModelType.LLM,
- },
- parameter_rules=rules
- )
-
- return entity
-```
-
-- Exception Error Mapping
-
-When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
-
-Runtime Errors:
-
-- `InvokeConnectionError` Connection error during invocation
-- `InvokeServerUnavailableError` Service provider unavailable
-- `InvokeRateLimitError` Rate limit reached
-- `InvokeAuthorizationError` Authorization failure
-- `InvokeBadRequestError` Invalid request parameters
-
-```python
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the error type thrown to the caller
- The value is the error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke error mapping
- """
-```
-
-For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
diff --git a/api/core/model_runtime/docs/en_US/images/index/image-1.png b/api/core/model_runtime/docs/en_US/images/index/image-1.png
deleted file mode 100644
index b158d44b29dcc2a8fa6d6d349ef8d7fb9f7d4cdd..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
literal 235102
zcmeFZXH-+&);3I0P!Z6ZD&2~7REl&I=>pPwmkyzKLPS6iq&KAn1SwKN?+}WBfD~x~
zlF&OO^p+6vM$dV^an28)*Z<$k7(3Y`ti8utbI*CLd0lfR?_a4aQeI=aMnptJsjT!u
zi-_n7;kf+$3K`*(ByK+wBBJZk4svp@l;z|&UU|6MIyl=95h=Y-(I+?1?xW8(*1CP^
z3J1si!v~S$L|V`7iSrpc6qHE9{;}dHn))e!8OBdvky~;niuHzsm7S*Z6axi!f0IE4
zkE@m}47d`whFzV-UwE#{*bQYFlM!LAX1_>`b|i`CSK*W4ib
zoLT=-=V5~N)SFn^2Uqrj#D;Vhxy97=7p1O&!=6%o<0Fcn@RxLWK*W@Gdzg38?m1NdhC0IE@vwr6q{?o5KzGbM~v>Jb-!_YVl8?v6t-Q%X^Q>D?RwH-6em2?OtOG
z7v4_nqJCwipUldjt04P6WpvW;tx1%vdT|r2DO7bZ
zc?#zwRHe5@LZ=MdDWcv?U0J;04$#Su(ipxG?soef&!sG0?&SxvrPO;tme1T;`t9%7)`TrnoOyG`>di#Ce;iCVki2fDZlRp^N#b;I8-l@i(1FzUUXqnm3$ew{_xa
z*mL)98D@6BXWC8ZUtB&je$K*48$rf>RQdV3gHmQclow)Q(Kgk?ds9>;{DjW-ei+wQ
zXzi|l)cJx*#lZ7N67I5!27=-q7W^cZuPQvyQ%`cAJ~I1q@Jx%5v5+NN`+nmiZV|Ga(M;T{rYcKhMmbJ#OIf3Hio
zqGs=ON3!f*T9mv&?6FY$@{x)pBi~iYcOTZ-fBIMKS++G+^kMN*GVQ^3Q}OS09#w@z
zzQnxYeDsm=#r3=IZlK)c5Y#6fcNZQJi7;e_2U7l6YL;2Jnwzy7@0K;9+#vi+=MIi*
zKr@GCGor{b;li5F&8cW3mVB}86)X2rK*MdfZ%7Lk5j6GPX1Ub7eq%GbRgD=?MhCU8
z+=`-mK~D1RWTGh4@Ll0;VFbDEa_vQ>bic_M`PSgA?K3gywTMlMd9g)14QJ^(c8Tb$
zZ$s9K6=dz-Vy$BgczkqX$@O8@-J492^o8F!$u}b}Zo1A7_463cn`|La=kn9~4>Nf|
zilpjoucLI5Kk0sd-6ilgzK`x8>)j*MOj8r$c#wHA1R`1)wL|t9hRMn}ASY_;#bUA4
zpta68(h3-j^hHFQXC3e9#w*f!#O=11M)}KbXCXv$|9G9|-y)*1Aa-5|GI>M3OH9nc
z7#3i^s_iNj^v*;X9X=Hl+oY89SDs!SdJ{eWxPc1Y
zJ~4LL?2XwjRgl7su&|LEWL9rJ-qDLBU3^<{L*^di`xlGxwD&0cUVMwZ{^{M1XZ;B_
z-;G|=XoYUeQ{CoFI8c)Cpzwd=br+JL@$lW(Jn^DOjxkhthpV~Y4<4nqgtkliIRyTLQV3SGpF>O;{>I9
zq=}?`X|!sK3D=arzqukM#2d}D&tBXsDy~oGKo5w{R9UWI4%^WKo
z@6;MPmP{NMWoZZI9BY~yBh9k%r1eY-+LVA9n5##|o38dk>B4+(QdM~J`94pKsQ*Nd
zN!tqwWePPgk;HsrN?|%;(ibut`YY&ydbGdkM+eDK!6lVy`7sCM=dNIe05QB|Sn&
zLI_*KXmq#cbM1Z&y(4Nw9CGwBBj2TrwUlopyQI|+n|yqitf#`ljWMW)rH8)9K_QJw
zvsWRaB&kfhyjrzTtx#Xz7^Wtw>{|u+g4RT;OT8R6iYrMgsW)&1&r-H($bE>Yh`18L
z%cv!~+#i2}kvk$i49D1DnvWn`pvM+Wmzb9sd4=r-8|)&PXPA4K)E}F;`ps0>6%D-t
z=bq=<$emEOx`Z)JGk4!o6P|G%cO18KwqkMQePm`G)ZftiV)$PBJ^1VJzDG#_?RUlC
z;_qJJ8}4IkFHIk0d>SwrbhbQBzx65nQ{NAfPqm+bBK40ngajVVxdactb?m76USVMQ
z)A^^9gQHl#N4ibfiVIH^-2nH1W6FeK$;qV1WZ&dLDRU`!dL8cO66NCPf}R<5#<_6L
z=)ujZ@6J?idu&fmo7O3!kf`cy9+ZGs4iy{Kod|C#RM(ZRmUr(Xbg8%8bbMb+(2Grq
zOS_GFMtO0&JGupkAw2vxcfWKEQ{h*)(V1H<-Dg8rP&XIuswfD>_@S(JI(>`18dp_3
zDc!R+X4WzNgB@z?b+f1fi7sF=R`SI51j3fhrWGe12hdO*78>Rrrb?yac~=|nJrq4p
zw=O)k6$w*xrQUhI1Dpk&J;f}aZtwbwjN0bAEvF{L1q}lI+cJVAa36ymf|+sRxWRM(
zi-`+rQeKiv#KxrARj8@y?X2yzkfD_Rx?WvDwUxOqyo~uk2qLzA~=xakJW;naAE<
zz5VHS?d{fFbgs6XM%c~HJDsWTyxn3Ng8-EEqKGF!0e-4s0WX5S%|u2d94Cx^FVI{s
z+SMFTM8;G_{i7)SBGevmxW69wNaPWXHC69F*49=$*5Ge3ks0*9BH1@(ZU-x#H;1){
zbzk-U$A!k`{h;z;exosQTXRK
zI1FyD(_^lsT?N~o+VT$grEX+2N>)n@$dueEQT%qZjIOnxYG5`VC(S8tXRHnac%ade
z9%e?fw=3(ODk2hH6J^Apri&b0DPHP_*$p065TWNywM1Nu`XdHKq@ZPEo4P9V*VIwx
zQQXneFICcOKvk2k_m(HQjL}Bwu+p<=(TLz8I{Y8$?
zCa~Ubz4a%_OCfNJmfWDO!kEu7qN3q|uaLU(@~M{ylbW->)Os_YlZkL?|Ao$Qnq5S)*L*3Vr2T`=v45-^H-ip8e<(@>Jjy_`U@JP0fq%hFv;K#~z0y-W&Fx)G-6L8yqud(n@qfDm6Y^F_eSHIo7tTivN
z*-Y8|2x@lPGe1~Y4{*)ET4B)8fxwe9;52zCHC(1jIvbaKzOk{UY_4Rk2dc%+9P%x~
zkryo&!B;i3vROH^eMbt7&M#OiShztSN?0t9wx+a$#sCStJf0gGX^4bAW4%G#;
z_qZLJK}B@z90C_8(qN~My-xY&lE9m$CaSuS)^sa#F3~#(a_5)l>Dg&i7cpEXnKH|`
zgInKqk?q54AmZHFAUXa-UMm8jaJDf}wpCLj;wGG5A-Y7&M0A;OMoc)QiJAYua|Pmu
zM5MnzCm|w=a3H$$?=k9x_g{bUgyYvW|9U6MA^D#XSG00S|L6Sj_g_P~SZwDBZ&%%v
z3_Xd6=FoyUiS^&LjZb2*xGx2N;33-Le
z&CN9{lJND3wFhah-BRFA{NKaxN4C_m?I4raSfPfmK-+0z;OZwkMD?%~`O9dM
zB_EFNmV==!Ju`>0f2c;cwCLcz_p0
zMox~T5IBe9aFtN3Khvjh@ep^>c<KZNxR6()mpi~T
zYx)!BC>UCd<}0IPTusLDEr+A3=wZ-BG!|Ap9P7askk{ZY9RFM+$G|2fB~UKb=+0l@
z#z-hhfJ*l}X25(l=)(U58^PxD#P-K}hn^RtABP-hI?61(FX2$FCTep-e(UwO|8;Wz
zMoNnpB+ME;N7SA>(JP9-LHc{UJ)-=ZA_$96F*fT_lO)bK@wW<9yZ53r?Fwc8<_1jH
z>;=2Uogh+*&{gYkuu|oGjRPR`25dYuzDMGw_y`*PoLnms}pZs-@Mm$J|aDCHeKOi@pK64~Ti0Uuy5n}eQ$Ksn+
z8`yfQ>npy9GXJApwm)LCq)fb2CGpB%+=KX%g1$w9JiWWCiPz7cd|oWcBHW~0@7}#b
zsPL2vQl+cX{nx|$(;ahkLsIn3rZ=+jnL^&)9N~gTPv(NOkx@MSGMtU0dVljo0?aIC
zYAz%0MDuR%;D)(Kxtm1d?~s|;r051*%;ymQP5A$KPck4-A2hS=$1BV9w1fVEcaxNj
zJq2HbMPc-9nZKFD4aUpFQVDmmolRuT&COXPEmI{hFRpb(dDFin(>y%#t@-#gF!ko&
z_<2q=Nzwx~k7)iTd(*2n)|WQ#99hUJDm=YXc+I}}e%kso`mZ1U%FiBDoLafobR$}R
z`h}Q$V>7z-vqb7EnZ`F1Sj5GsrT_7NZbg3t-ouZw)q!}CV2e8~XzMABe8DhVtPh&Y
zQ#88M=kM&&OLFkziH2&elQ-`()lV^+`IeOV_Js?iq;~_>yBRBA@&463KFUg*T)=N>
zTb#ue`ol@N0$Mwku7aEXIPd@X#|XtQaCgzIPV1$w5stTdO3A98UU&HZ@2t3_o}=GO
zQ6m}uS)chSI+m6@>@SSp&T?lY+M3dcP3cj@wp{xAzqo{m#Ir(Hm9Q}PuDUYtC
z5a#!Vcy#mT(hqV09P*qzh`yEpYhP{E
zzq#;1fh=b)k=J90y3-n&mRHs1tRE#~F%SMOdhN|~lB|ivKj~xniq#XS@);p?rRCzr
ze+OY{mwT2B!o3I+w(PZ*lbDv?fBqMb!9tzx;OM>^C%M_S`(+GQH=$?RxFH(_lHSW7
zBZ^&6hqpR0i_VP~f9V6Olf7uM$vDY}qni&z0gX0k>aY@e{xdWk`h;3)Q99tX5OI|=
zYqK1);EP}?fh^|D;oIdLu(YpOlyXDJCWBpMPUW!P&K%v_)741L-pqBg;1xO>R5^Bn
z5)gNS+SW4&HiXWi@N8#YdFap4P%ynGN&w;;xKpXxG(_@eC$r|uc|-&R0!<-3_9@yE
z@#U2H7B}SS#1{T#spZRuVeepOfl`z~n;W1n0|Mt*`IA#T?ujYrxN&W4<62A;!G
zzA1eFUZW5lBTMGWbbiJmGXQgpD9~UDt{Bti6_;YYcyB$%DQt2+
zSp7~aevwlKC!Zd->DPV{*m$;<0U4VZT=B)So-|ZL_XbT2Ci>?~4I9-M%smh2Q6&Sq{7>{5O#9)yD+q)@0*7%M
zBCU0u4_bY`6%2|7CK=m)`f|Rx7|oadqx$?gOBr0_N?$~;<#g51&9-24ebt-S8*gyO
zDi?7FacOk5DMwZ;Uj_;a!Tvos3?`%cjY!&za`T?T*8N6fp!2M4$la*&$Gx8sH+9p$
z{UIXT8LxTah3TxP8p`RVQb}2xW-9I-#8Tyd)F#u+4D9GeS91pU?PO51dgPBu49XB9
zA7EjT0{}^#Eo-9>tA`*IkZT5RGc~DXT0uu2JWHz9l^SNcerF3Kz>tdz>H3+kJk4uh
z&~3i$Mtj2$SA*9JsdZeV)cr;H4~Q`Li*U3wSSw=+JCu7#PYcG(U6NToa7EO|2OVpKFIE#?1Q
z+b2)NmdJUM48(XB%<+-psBw
zZ5qemW!*bYrS<;l-AQq{*`uQ)>``v+G@x$
zbmqfBczBlIzP?I54jmGa87zdVz)5(La;;_uqJ;)OyBgdxoqT?Nm?vnT6Ugb+_*SvO
z1J3H(>;!3+9<&)46!qRJZ7pctScN0X15>H5y;`=>r-IInhJ}ag!e)6H_J^TT>jf|S
z&OMl&c)$`D_M-TpgA|udzmH~Vx*|Z{TUTyEAHP_;Nsfm4n(DOnl$_Fj#hoC#bXSbi
zi2uyDC@(EKZHp9juI|;58D6(eqKvfn7L=Eav|{W|c+E({dP_+YClrNdHR-Z$-_&U@tN=G#usmPV@%Vc++>Mz8iZ(50G-PsrIR>
z2=5ahx{BRog+~y!+y?+l`5q0NAn}2PV81GtYuM9-1>Yh14A(^;;0~3$P@#TU5RQ3H
zrE3x9_^`q;zW0+kv>oY#MM@({D1NNPKHa|Sz}96r!T}|fKHM|WqP#aE1XQ?qPt&Mm
zU9*qbf;1E27{&)hJpQ#QvuW}P-B?`FH!B;k&Z?c;l{v3-1+6_kX*2aLjrJWV
zt6wvV@E9RKQ6`f+X-vDwzyl#+6yw2#Fq`LpO#fslNPJx3ggp3SgEruwaLUgW{?#3{
z-Op(It(`f=K56PD0i?w~XZYLkeHV!FTa6w4D_+
z0gq<~MX-%Xj>`5$b9&=brPW@RKo6UJ1&vQro^uIE#uDCS%HNo3W|8e*(14Svu=~1K
z?G(I0F8WXwvf_7O%Zzn#f;2Wk-SNnQ@|R0>%{hhIwB%a$L|w4qK{74cgVGtFHPaI@
z+|*%ZXbG6n(w%8API5J@21)PQO_y5T9Xo`};8TcS@@FvFZaYk=PU@hBG*Gi25}$+*
z1Dk|IU0|}wgCf)4cY5#Bo0kgBKEVe}mQ_DHNHU`IFm$htQ8|bt6=+#+R+q6ElexgW
z8GF?GM;a&iPl#X68?x(DM9~S$w?=1jt}?%{^rR3-eI*k4nH%uPrGB#d8D&+U#&lhf
zpFlamc=7@wO-bI9F@mi|)?gf!dzxuLwy{&pgvYeGCkis=O<6NS{V!k3yqyD63|fii
zJnc~jOLuJ)lZv?1uGpfJlj3T1eKCe>g#$DW&orh-XfnPQsLj>nV$U?%(g2e@{EHzZ
zdm{2pEA_rj0$?8>CVc-uTM8B1h(+RQ?L&urGKv69m|2=#vv9E&18X9KtQSd2yyBi0R)~w}X2*wV|_)c6Hj)zi9&EPI_
z6@%IzFuTsT9j7E!#n7n&@8UdXqk^s~5OR&x*CFP()dG`)+Tb6W5ls$b^xwmsY<8r4
z+M
zwIt5xFF!r@f!u5dwn7^yZA4CM><@%bvpL$jr#hz6bd8H
z7+ue;(*MvF{+l}=HpxnU{9Nr)S*S8a!{rQESOVUV7b&{k7IE`Kl}nbeHURmf*#Lf%
zGXJ41{p5kDfB({c@d|?S_fyyt>rN
zX~ACP?<{Miipd|R{k$|
z4Pue
z1z5eK%r^kF)_PucY=1E=87oz`eN73K~Ntw!|5i4VELmw*z@0-%Lzz~;FBSPe9A{&qd8`2a$zDq8j5=Khc8
zFhWes{4zK2TUIMIAoPLrjqpL;4L|>8U*{EG06EimC9i_#E4f~7q=G^ho06t*PEJf!
zApa)4E3#vQb=77J>x29WTq+gd1aV;nKuha&@SP!fY|wt}k+7cmeFj?|3DhZ|@EMiJ
zG|j!5wuw~NIt!V|B}IMHq;hJzT;!p-3ZQ*@X(&RQ&EdCXkQw-kY95Ks8Y#M?ykX6r
zL(Sh=1HOkex`1$7*6K1@nE>d)>|qLBzwcJP^Q_)6L0At9yf|YBO4l<3=gy@y{gh#Q
zKzZkVVaW1TW>7Jy)I3-{aV~51#I?FX1?@mD%TXOP%0`Oua5X0m_8+
z58!9if6`|ygf1)i3Wghrn2~D=X@jqX@?tWCt6EW6tl?v3)5tg}>(w6x;k<(Qb2JnW
z6k~9O0#Dr7#nwq1EBtICP;&{0k$$q5-{6ZjQhQc?_vzz84SogxjdfyTsV_y27{r=}
zXFW23J*cqZKrL(Lu#kJ$sv$eTt;cX}?X4)n0*MHyPJsiOdT@KWa9o^J+2WV9edx}j
z%Ji8rrz+U~QI5}1
zB|)~5mlvwpx4&@T41Y$uut0%4;%fRZNy`52=tb0i|B-In@#m?A_bt*pEe8z){9eF!5pC|nnka>_#pu?brv7L0=*r%UVd$jwWa;+1LE}1vklNP}S}1g4;4JXy&~s~lMD~4}xqo-WPU}twf$h11dw$Yz
zt3O2-y*AA=7ES?&1)+X;BDu_^pS*C`#s36L_aOt|Sq|4*A3N?Z_iY65P+*X*qPs`U
z07os5)cxDptLZ!go)$TC^&5#hW((!#9rEaKh6@@+)`=DU!ZmWOb!F87{_13qO@vwH
zXRhGZ-X7+%ox`tuYBP#Vhc}?Hd8gyWgsm$f4&D!{wjK0X10DS1yS-Pq{8^xJc@e$I
zBjC=%3w35{>&p=F`X!UVPu|3>WWU+Vm5V)uTnaoWU@mkw{3ElGwcx$vg_k|4tPW(#
zr8hh^7pEr{vM4QWC`9yQV>dl(rJ1ZKOpAoLI$}{`Rb-VqXXy8-eK2hGzb}&
z=?rd%osCbS*u>naYyu}jSxz8g;oM6OUC{!q%X&sQHOi%L%&_U4;CtFFrUM|{5)v>&
zLP7g)hRj&d&d}x;AZ~GI=(z^->DCWIoQGPtuUr`5in!6Nsfcb3QR8Y|gI)1neqr-k#@g
zTlRpW8rRI-FdN@$boV!l?#cvLU^j0q6f>ro1;rW%6yJvr8RaR@wGA30G=eeFqcsGg
zQlBB@z)_DK^u6Hsfbp!&ad0u7AsJClra7PNa+v+R`)5
z_ph(A@b7Wg5MN+C8P^^@1wWNE3&zhjH_maIK|$*YLQl&TnZEx~ACqOcB;?nbeFxdN
z9Ub~`n~q&1vU7C_C@6n?_`XGxepc+s%6M7Ol2M&h{2g_Kb7Rnme6@}rmXao^d@aYAEp#6zSUpsj!@4{X=22QRy6=$b_l%U|m&|!_Rr!hC4=-atcI#vASl{U{
zYt8ucHF;mT{cjNzZQjM$6^N^sDqmxm#A1arN2uA=&jCGSK?iEyOoDqIxZ|kM-idvg
zi@g^{V0rqXY^BPL2h0TvMyDv0?^y`#WB96zF$M;FcwQcehW*eGoePNI!g=2i+#VI2{H5a5JH0_tgATQ!EogD4RWhw@0&mX>78f>bmPi5be|l_#
zZrZ-shd~Q117^b14u(a+V{gLX=mt2s!;L<7@I{#UiV`D#f=Zj8adBMVOTf+6C#J476*~L*A5!#8+BB1!T(-)ZQv=a`GLzw3
z-kb-+GT^{5g)NwicDpyUsRMVIZr%h5+Pfs`z1i}`eNmQ@s?s5RC9PIoW^UERRhd9f
z+rzT^E9c*iA>J{UrbVB8XJ_+PDXDUq@jS`G&0jNi^}Zl90dC|ua>z^SO6e{`KtSG@
zf3@8;#e8Yy`9TMSQ(Xm8a?wgDh&xC
zuAj1s7RPz-6`Nz50?$_TB|~QU@O>IVLcM!Go(!*&X*!<_xg_?sLNcSr`uh`AdC{(t
zwg12t*KeXUq*J6G&vt&g`wOZ+m{<@L4N6k3V|er9>XPz{#^&RuEuQ*bW;69{kG(-x
zkO0+J<2<5=0knaSyW8`LPwWqABf$7YSl