diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 587c663482..c029e00553 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,6 @@ import logging import time -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import select @@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app: App, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, + graph_engine_layers: Sequence[GraphEngineLayer] = (), ): super().__init__( queue_manager=queue_manager, variable_loader=variable_loader, app_id=application_generate_entity.app_config.app_id, + graph_engine_layers=graph_engine_layers, ) self.application_generate_entity = application_generate_entity self.conversation = conversation @@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) workflow_entry.graph_engine.layer(persistence_layer) + for layer in self._graph_engine_layers: + workflow_entry.graph_engine.layer(layer) generator = workflow_entry.run() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 3c9bf176b5..eab2256426 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): ) workflow_entry.graph_engine.layer(persistence_layer) + for layer in self._graph_engine_layers: + workflow_entry.graph_engine.layer(layer) generator = workflow_entry.run() diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 5e2bd17f8c..73725e75b5 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,5 +1,5 @@ import time -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -27,6 +27,7 @@ from core.app.entities.queue_entities import ( ) from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, @@ -69,10 +70,12 @@ class WorkflowBasedAppRunner: queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, app_id: str, + graph_engine_layers: Sequence[GraphEngineLayer] = (), ): self._queue_manager = queue_manager self._variable_loader = variable_loader self._app_id = app_id + self._graph_engine_layers = graph_engine_layers def _init_graph( self, diff --git a/api/core/app/layers/pause_state_persist_layer.py 
b/api/core/app/layers/pause_state_persist_layer.py new file mode 100644 index 0000000000..3dee75c082 --- /dev/null +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -0,0 +1,71 @@ +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events.base import GraphEngineEvent +from core.workflow.graph_events.graph import GraphRunPausedEvent +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.factory import DifyAPIRepositoryFactory + + +class PauseStatePersistenceLayer(GraphEngineLayer): + def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str): + """Create a PauseStatePersistenceLayer. + + The `state_owner_user_id` is used when creating the state file for a pause. + It should generally be the ID of the workflow's creator. + """ + if isinstance(session_factory, Engine): + session_factory = sessionmaker(session_factory) + self._session_maker = session_factory + self._state_owner_user_id = state_owner_user_id + + def _get_repo(self) -> APIWorkflowRunRepository: + return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker) + + def on_graph_start(self) -> None: + """ + Called when graph execution starts. + + This is called after the engine has been initialized but before any nodes + are executed. Layers can use this to set up resources or log start information. + """ + pass + + def on_event(self, event: GraphEngineEvent) -> None: + """ + Called for every event emitted by the engine. + + This method receives all events generated during graph execution, including: + - Graph lifecycle events (start, success, failure) + - Node execution events (start, success, failure, retry) + - Stream events for response nodes + - Container events (iteration, loop) + + Args: + event: The event emitted by the engine + """ + if not isinstance(event, GraphRunPausedEvent): + return + + assert self.graph_runtime_state is not None + workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + assert workflow_run_id is not None + repo = self._get_repo() + repo.create_workflow_pause( + workflow_run_id=workflow_run_id, + state_owner_user_id=self._state_owner_user_id, + state=self.graph_runtime_state.dumps(), + ) + + def on_graph_end(self, error: Exception | None) -> None: + """ + Called when graph execution ends. + + This is called after all nodes have been executed or when execution is + aborted. Layers can use this to clean up resources or log final state. + + Args: + error: The exception that caused execution to fail, or None if successful + """ + pass
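For orientation, a minimal sketch of how this layer could be wired into one of the runners changed above. The engine URL and user id are placeholders, and in the application the layer would be constructed wherever the runner is created:

```python
from sqlalchemy import create_engine

from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer

# Placeholder engine; production code would reuse the app's configured engine.
engine = create_engine("postgresql://dify:dify@localhost:5432/dify")

# An Engine is accepted directly and wrapped into a sessionmaker internally.
pause_layer = PauseStatePersistenceLayer(
    session_factory=engine,
    state_owner_user_id="workflow-creator-user-id",
)

# The runners now accept extra layers and register each one on the graph
# engine before running, e.g.:
#   runner = AdvancedChatAppRunner(..., graph_engine_layers=(pause_layer,))
```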
diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index 185f0ad620..f4ce9052e0 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution +from .workflow_pause import WorkflowPauseEntity __all__ = [ "AgentNodeStrategyInit", @@ -12,4 +13,5 @@ __all__ = [ "VariablePool", "WorkflowExecution", "WorkflowNodeExecution", + "WorkflowPauseEntity", ] diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py new file mode 100644 index 0000000000..16ad3d639d --- /dev/null +++ b/api/core/workflow/entities/pause_reason.py @@ -0,0 +1,49 @@ +from enum import StrEnum, auto +from typing import Annotated, Any, ClassVar, TypeAlias + +from pydantic import BaseModel, Discriminator, Tag + + +class _PauseReasonType(StrEnum): + HUMAN_INPUT_REQUIRED = auto() + SCHEDULED_PAUSE = auto() + + +class _PauseReasonBase(BaseModel): + TYPE: ClassVar[_PauseReasonType] + + +class HumanInputRequired(_PauseReasonBase): + TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED + + +class SchedulingPause(_PauseReasonBase): + TYPE = _PauseReasonType.SCHEDULED_PAUSE + + message: str + + +def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None: + if isinstance(v, _PauseReasonBase): + return v.TYPE + elif isinstance(v, dict): + reason_type_str = v.get("TYPE") + if reason_type_str is None: + return None + try: + reason_type = _PauseReasonType(reason_type_str) + except ValueError: + return None + return reason_type + else: + # return None if the discriminator value isn't found + return None + + +PauseReason: TypeAlias = Annotated[ + ( + Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)] + | Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)] + ), + Discriminator(_get_pause_reason_discriminator), ]
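A quick sketch of how the discriminated union above behaves (assuming Pydantic v2 semantics; the message value is made up):

```python
from pydantic import TypeAdapter

from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause

adapter = TypeAdapter(PauseReason)

# Dict input is routed by its "TYPE" key through the callable discriminator;
# StrEnum + auto() makes each tag the lowercased member name.
reason = adapter.validate_python({"TYPE": "scheduled_pause", "message": "nightly batch window"})
assert isinstance(reason, SchedulingPause)

# Model instances are routed by their class-level TYPE marker instead.
assert isinstance(adapter.validate_python(HumanInputRequired()), HumanInputRequired)
```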
diff --git a/api/core/workflow/entities/workflow_pause.py b/api/core/workflow/entities/workflow_pause.py new file mode 100644 index 0000000000..2f31c1ff53 --- /dev/null +++ b/api/core/workflow/entities/workflow_pause.py @@ -0,0 +1,61 @@ +""" +Domain entities for workflow pause management. + +This module contains the domain model for workflow pause, which is used +by the core workflow module. These models are independent of the storage mechanism +and don't contain implementation details like tenant_id, app_id, etc. +""" + +from abc import ABC, abstractmethod +from datetime import datetime + + +class WorkflowPauseEntity(ABC): + """ + Abstract base class for workflow pause entities. + + This domain model represents a paused workflow execution state, + without implementation details like tenant_id, app_id, etc. + It provides the interface for managing workflow pause/resume operations + and state persistence through file storage. + + The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times, + it will generate multiple `WorkflowPauseEntity` records. + """ + + @property + @abstractmethod + def id(self) -> str: + """The identifier of the current WorkflowPauseEntity""" + pass + + @property + @abstractmethod + def workflow_execution_id(self) -> str: + """The identifier of the workflow execution record the pause is associated with. + Corresponds to `WorkflowExecution.id`. + """ + + @abstractmethod + def get_state(self) -> bytes: + """ + Retrieve the serialized workflow state from storage. + + This method should load and return the workflow execution state + that was saved when the workflow was paused. The state contains + all necessary information to resume the workflow execution. + + Returns: + bytes: The serialized workflow state containing + execution context, variable values, node states, etc. + + """ + ... + + @property + @abstractmethod + def resumed_at(self) -> datetime | None: + """`resumed_at` returns the resumption time of the current pause, or `None` if + the pause has not been resumed yet. + """ + pass diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 83b9281e51..6f95ecc76f 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -92,13 +92,111 @@ class WorkflowType(StrEnum): class WorkflowExecutionStatus(StrEnum): + # State diagram for the workflow status: + # (@) means start, (*) means end + # + # ┌------------------>------------------------->------------------->--------------┐ + # | | + # | ┌-----------------------<--------------------┐ | + # ^ | | | + # | | ^ | + # | V | | + # ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V + # | Scheduled |------->| Running |---------------------->| Paused | | + # └-----------┘ └-----------------------┘ └-----------┘ | + # | | | | | | | + # | | | | | | | + # ^ | | | V V | + # | | | | | ┌---------┐ | + # (@) | | | └------------------------>| Stopped |<----┘ + # | | | └---------┘ + # | | | | + # | | V V + # | | ┌-----------┐ | + # | | | Succeeded |------------->--------------┤ + # | | └-----------┘ | + # | V V + # | ┌--------┐ | + # | | Failed |---------------------->----------------┤ + # | └--------┘ | + # V V + # ┌---------------------┐ | + # | Partially Succeeded |---------------------->-----------------┘--------> (*) + # └---------------------┘ + # + # Mermaid diagram: + # + # --- + # title: State diagram for Workflow run state + # --- + # stateDiagram-v2 + # scheduled: Scheduled + # running: Running + # succeeded: Succeeded + # failed: Failed + # partial_succeeded: Partial Succeeded + # paused: Paused + # stopped: Stopped + # + # [*] --> scheduled + # scheduled --> running: Start Execution + # running --> paused: Human input required + # paused --> running: Human input added + # paused --> stopped: User stops execution + # running --> succeeded: Execution finishes without any error + # running --> failed: Execution finishes with errors + # running --> stopped: User stops execution + # running --> partial_succeeded: some errors occurred and were handled during execution + # + # scheduled --> stopped: User stops execution + # + # succeeded --> [*] + # failed --> [*] + # partial_succeeded --> [*] + # stopped --> [*] + + # `SCHEDULED` means that the workflow is scheduled to run, but has not + # started running yet (possibly due to worker saturation). + # + # This enum value is currently unused. + SCHEDULED = "scheduled" + + # `RUNNING` means the workflow is executing. RUNNING = "running" + + # `SUCCEEDED` means the execution of the workflow succeeded without any errors. SUCCEEDED = "succeeded" + + # `FAILED` means the execution of the workflow failed with errors. FAILED = "failed" + + # `STOPPED` means the execution of the workflow was stopped, either manually + # by the user, or automatically by the Dify application (e.g., the moderation + # mechanism).
STOPPED = "stopped" + + # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow + # execution, but they were successfully handled (e.g., by using an error + # strategy such as "fail branch" or "default value"). PARTIAL_SUCCEEDED = "partial-succeeded" + + # `PAUSED` indicates that the workflow execution is temporarily paused + # (e.g., awaiting human input) and is expected to resume later. PAUSED = "paused" + def is_ended(self) -> bool: + return self in _END_STATE + + +_END_STATE = frozenset( + [ + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + WorkflowExecutionStatus.STOPPED, + ] +) + class WorkflowNodeExecutionMetadataKey(StrEnum): """ diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index c26c98c496..e9f109c88c 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -3,6 +3,8 @@ from typing import final from typing_extensions import override +from core.workflow.entities.pause_reason import SchedulingPause + from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand from .command_processor import CommandHandler @@ -25,4 +27,7 @@ class PauseCommandHandler(CommandHandler): def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: assert isinstance(command, PauseCommand) logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) - execution.pause(command.reason) + # Convert string reason to PauseReason if needed + reason = command.reason + pause_reason = SchedulingPause(message=reason) + execution.pause(pause_reason) diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py index 6482c927d6..3d587d6691 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/core/workflow/graph_engine/domain/graph_execution.py @@ -8,6 +8,7 @@ from typing import Literal from pydantic import BaseModel, Field +from core.workflow.entities.pause_reason import PauseReason from core.workflow.enums import NodeState from .node_execution import NodeExecution @@ -41,7 +42,7 @@ class GraphExecutionState(BaseModel): completed: bool = Field(default=False) aborted: bool = Field(default=False) paused: bool = Field(default=False) - pause_reason: str | None = Field(default=None) + pause_reason: PauseReason | None = Field(default=None) error: GraphExecutionErrorState | None = Field(default=None) exceptions_count: int = Field(default=0) node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) @@ -106,7 +107,7 @@ class GraphExecution: completed: bool = False aborted: bool = False paused: bool = False - pause_reason: str | None = None + pause_reason: PauseReason | None = None error: Exception | None = None node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) exceptions_count: int = 0 @@ -130,7 +131,7 @@ class GraphExecution: self.aborted = True self.error = RuntimeError(f"Aborted: {reason}") - def pause(self, reason: str | None = None) -> None: + def pause(self, reason: PauseReason) -> None: """Pause the graph execution without marking it complete.""" if self.completed: raise RuntimeError("Cannot pause execution that has completed") diff --git 
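The handler above always wraps the command's plain-string reason into a `SchedulingPause`. A sketch of driving this path from outside the engine; the in-memory channel here is a hypothetical stand-in for the `RedisChannel` imported in the runner diff:

```python
from core.workflow.graph_engine.entities.commands import GraphEngineCommand, PauseCommand


class InMemoryChannel:
    """Hypothetical minimal command channel for illustration."""

    def __init__(self) -> None:
        self._pending: list[GraphEngineCommand] = []

    def send_command(self, command: GraphEngineCommand) -> None:
        self._pending.append(command)

    def fetch_commands(self) -> list[GraphEngineCommand]:
        commands, self._pending = self._pending, []
        return commands


channel = InMemoryChannel()
# `reason` is now a plain string with a default; the command handler converts
# it into SchedulingPause(message=...) before pausing the execution.
channel.send_command(PauseCommand(reason="operator requested a maintenance pause"))
```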
diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 6070ed8812..0d51b2b716 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -36,4 +36,4 @@ class PauseCommand(GraphEngineCommand): """Command to pause a running workflow execution.""" command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") - reason: str | None = Field(default=None, description="Optional reason for pause") + reason: str = Field(default="unknown reason", description="reason for pause") diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index b054ebd7ad..5b0f56e59d 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -210,7 +210,7 @@ class EventHandler: def _(self, event: NodeRunPauseRequestedEvent) -> None: """Handle pause requests emitted by nodes.""" - pause_reason = event.reason or "Awaiting human input" + pause_reason = event.reason self._graph_execution.pause(pause_reason) self._state_manager.finish_execution(event.node_id) if event.node_id in self._graph.nodes: diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index dd2ca3f93b..7071a1f33a 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -247,8 +247,11 @@ class GraphEngine: # Handle completion if self._graph_execution.is_paused: + pause_reason = self._graph_execution.pause_reason + assert pause_reason is not None, "pause_reason should not be None when execution is paused."
+ # Ensure we have a valid PauseReason for the event paused_event = GraphRunPausedEvent( - reason=self._graph_execution.pause_reason, + reason=pause_reason, outputs=self._graph_runtime_state.outputs, ) self._event_manager.notify_layers(paused_event) diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py index ecd8e12ca5..b70f36ec9e 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -216,7 +216,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None: execution = self._get_workflow_execution() execution.status = WorkflowExecutionStatus.PAUSED - execution.error_message = event.reason or "Workflow execution paused" execution.outputs = event.outputs self._populate_completion_statistics(execution, update_finished=False) @@ -296,7 +295,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.PAUSED, - error=event.reason, + error="", update_outputs=False, ) diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index 0da962aa1c..9faafc3173 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -1,5 +1,6 @@ from pydantic import Field +from core.workflow.entities.pause_reason import PauseReason from core.workflow.graph_events import BaseGraphEvent @@ -44,7 +45,7 @@ class GraphRunPausedEvent(BaseGraphEvent): """Event emitted when a graph run is paused by user command.""" - reason: str | None = Field(default=None, description="reason for pause") + reason: PauseReason = Field(..., description="reason for pause") outputs: dict[str, object] = Field( default_factory=dict, description="Outputs available to the client while the run is paused.",
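With `reason` now typed as `PauseReason`, consumers can branch on structure rather than parsing strings. A hypothetical layer illustrating the pattern (a sketch only, not part of this diff):

```python
from core.workflow.entities.pause_reason import HumanInputRequired, SchedulingPause
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent


class PauseLoggingLayer(GraphEngineLayer):
    """Hypothetical layer that logs why a run paused."""

    def on_graph_start(self) -> None:
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        if not isinstance(event, GraphRunPausedEvent):
            return
        # Structural pattern matching on the typed reason.
        match event.reason:
            case HumanInputRequired():
                print("paused: waiting for human input")
            case SchedulingPause(message=message):
                print(f"paused by scheduler/command: {message}")

    def on_graph_end(self, error: Exception | None) -> None:
        pass
```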
description="pause reason") diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index e49f9a8c81..2d6d9760af 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -1,6 +1,7 @@ from collections.abc import Mapping from typing import Any +from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult, PauseRequestedEvent from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig @@ -64,7 +65,7 @@ class HumanInputNode(Node): return self._pause_generator() def _pause_generator(self): - yield PauseRequestedEvent(reason=self._node_data.pause_reason) + yield PauseRequestedEvent(reason=HumanInputRequired()) def _is_completion_ready(self) -> bool: """Determine whether all required inputs are satisfied.""" diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index 40835a936f..5e0878e873 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -3,6 +3,7 @@ from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage from core.variables.segments import Segment +from core.workflow.system_variable import SystemVariableReadOnlyView class ReadOnlyVariablePool(Protocol): @@ -30,6 +31,9 @@ class ReadOnlyGraphRuntimeState(Protocol): All methods return defensive copies to ensure immutability. """ + @property + def system_variable(self) -> SystemVariableReadOnlyView: ... 
+ @property def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index 664c365295..8539727fd6 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -6,6 +6,7 @@ from typing import Any from core.model_runtime.entities.llm_entities import LLMUsage from core.variables.segments import Segment +from core.workflow.system_variable import SystemVariableReadOnlyView from .graph_runtime_state import GraphRuntimeState from .variable_pool import VariablePool @@ -42,6 +43,10 @@ class ReadOnlyGraphRuntimeStateWrapper: self._state = state self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) + @property + def system_variable(self) -> SystemVariableReadOnlyView: + return self._state.variable_pool.system_variables.as_view() + @property def variable_pool(self) -> ReadOnlyVariablePoolWrapper: return self._variable_pool_wrapper diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index 6716e745cd..29bf19716c 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -1,4 +1,5 @@ from collections.abc import Mapping, Sequence +from types import MappingProxyType from typing import Any from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator @@ -108,3 +109,102 @@ class SystemVariable(BaseModel): if self.invoke_from is not None: d[SystemVariableKey.INVOKE_FROM] = self.invoke_from return d + + def as_view(self) -> "SystemVariableReadOnlyView": + return SystemVariableReadOnlyView(self) + + +class SystemVariableReadOnlyView: + """ + A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. + + This class wraps a SystemVariable instance and provides read-only access to all its fields. + It always reads the latest data from the wrapped instance and prevents any write operations. + """ + + def __init__(self, system_variable: SystemVariable) -> None: + """ + Initialize the read-only view with a SystemVariable instance. 
+ + Args: + system_variable: The SystemVariable instance to wrap + """ + self._system_variable = system_variable + + @property + def user_id(self) -> str | None: + return self._system_variable.user_id + + @property + def app_id(self) -> str | None: + return self._system_variable.app_id + + @property + def workflow_id(self) -> str | None: + return self._system_variable.workflow_id + + @property + def workflow_execution_id(self) -> str | None: + return self._system_variable.workflow_execution_id + + @property + def query(self) -> str | None: + return self._system_variable.query + + @property + def conversation_id(self) -> str | None: + return self._system_variable.conversation_id + + @property + def dialogue_count(self) -> int | None: + return self._system_variable.dialogue_count + + @property + def document_id(self) -> str | None: + return self._system_variable.document_id + + @property + def original_document_id(self) -> str | None: + return self._system_variable.original_document_id + + @property + def dataset_id(self) -> str | None: + return self._system_variable.dataset_id + + @property + def batch(self) -> str | None: + return self._system_variable.batch + + @property + def datasource_type(self) -> str | None: + return self._system_variable.datasource_type + + @property + def invoke_from(self) -> str | None: + return self._system_variable.invoke_from + + @property + def files(self) -> Sequence[File]: + """ + Get a copy of the files from the wrapped SystemVariable. + + Returns: + A defensive copy of the files sequence to prevent modification + """ + return tuple(self._system_variable.files) # Convert to immutable tuple
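Behavior of the view in a nutshell (illustrative; assumes the optional SystemVariable fields default to None/empty as in the model above):

```python
from core.workflow.system_variable import SystemVariable

sys_var = SystemVariable(user_id="u-1", workflow_id="wf-1")
view = sys_var.as_view()

assert view.user_id == "u-1"
assert view.files == ()  # sequences come back as immutable tuples

# The view defines only read-only properties, so assignment such as
# `view.user_id = "u-2"` raises AttributeError, while reads always reflect
# the latest data on the wrapped SystemVariable.
```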
+ + @property + def datasource_info(self) -> Mapping[str, Any] | None: + """ + Get a read-only view of the datasource info from the wrapped SystemVariable. + + Returns: + A view of the datasource info mapping to prevent modification + """ + if self._system_variable.datasource_info is None: + return None + return MappingProxyType(self._system_variable.datasource_info) + + def __repr__(self) -> str: + """Return a string representation of the read-only view.""" + return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})" diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 2960cde242..a609f13dbc 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -85,7 +85,7 @@ class Storage: case _: raise ValueError(f"unsupported storage type {storage_type}") - def save(self, filename, data): + def save(self, filename: str, data: bytes): self.storage_runner.save(filename, data) @overload diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 0393206e54..8ddedb24ae 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -8,7 +8,7 @@ class BaseStorage(ABC): """Interface for file storage.""" @abstractmethod - def save(self, filename, data): + def save(self, filename: str, data: bytes): raise NotImplementedError @abstractmethod diff --git a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py new file mode 100644 index 0000000000..1ab4202674 --- /dev/null +++ b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py @@ -0,0 +1,41 @@ +"""add WorkflowPause model + +Revision ID: 03f8dcbc611e +Revises: ae662b25d9bc +Create Date: 2025-10-22 16:11:31.805407 + +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "03f8dcbc611e" +down_revision = "ae662b25d9bc" +branch_labels = None +depends_on = None + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "workflow_pauses", + sa.Column("workflow_id", models.types.StringUUID(), nullable=False), + sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False), + sa.Column("resumed_at", sa.DateTime(), nullable=True), + sa.Column("state_object_key", sa.String(length=255), nullable=False), + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("workflow_pauses_pkey")), + sa.UniqueConstraint("workflow_run_id", name=op.f("workflow_pauses_workflow_run_id_key")), + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust!
### + op.drop_table("workflow_pauses") + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 779484283f..1c09b4610d 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -88,6 +88,7 @@ from .workflow import ( WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom, + WorkflowPause, WorkflowRun, WorkflowType, ) @@ -177,6 +178,7 @@ __all__ = [ "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", + "WorkflowPause", "WorkflowRun", "WorkflowRunTriggeredFrom", "WorkflowToolProvider", diff --git a/api/models/base.py b/api/models/base.py index 76848825fe..3660068035 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -1,6 +1,12 @@ -from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass +from datetime import datetime +from sqlalchemy import DateTime, func, text +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column + +from libs.datetime_utils import naive_utc_now +from libs.uuid_utils import uuidv7 from models.engine import metadata +from models.types import StringUUID class Base(DeclarativeBase): @@ -13,3 +19,34 @@ class TypeBase(MappedAsDataclass, DeclarativeBase): """ metadata = metadata + + +class DefaultFieldsMixin: + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + # NOTE: The default and server_default serve as fallback mechanisms. + # The application can generate the `id` before saving to optimize + # the insertion process (especially for interdependent models) + # and reduce database roundtrips. + default=uuidv7, + server_default=text("uuidv7()"), + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default=naive_utc_now, + server_default=func.current_timestamp(), + ) + + updated_at: Mapped[datetime] = mapped_column( + __name_pos=DateTime, + nullable=False, + default=naive_utc_now, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(id={self.id})>" diff --git a/api/models/workflow.py b/api/models/workflow.py index b898f02612..d312b96b39 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -13,8 +13,11 @@ from core.file.constants import maybe_file_object from core.file.models import File from core.variables import utils as variable_utils from core.variables.variables import FloatVariable, IntegerVariable, StringVariable -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import NodeType +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from core.workflow.enums import NodeType, WorkflowExecutionStatus from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -35,7 +38,7 @@ from factories import variable_factory from libs import helper from .account import Account -from .base import Base +from .base import Base, DefaultFieldsMixin from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType from .types import EnumText, StringUUID @@ -247,7 +250,9 @@ class Workflow(Base): return node_type @staticmethod - def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None: + def get_enclosing_node_type_and_id( + node_config: Mapping[str, Any], + 
) -> tuple[NodeType, str] | None: in_loop = node_config.get("isInLoop", False) in_iteration = node_config.get("isInIteration", False) if in_loop: @@ -306,7 +311,10 @@ class Workflow(Base): if "nodes" not in graph_dict: return [] - start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) + start_node = next( + (node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), + None, + ) if not start_node: return [] @@ -359,7 +367,9 @@ class Workflow(Base): return db.session.execute(stmt).scalar_one() @property - def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: + def environment_variables( + self, + ) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: self._environment_variables = "{}" @@ -376,7 +386,9 @@ class Workflow(Base): ] # decrypt secret variables value - def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: + def decrypt_func( + var: Variable, + ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): @@ -537,7 +549,10 @@ class WorkflowRun(Base): version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(sa.Text) inputs: Mapped[str | None] = mapped_column(sa.Text) - status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded + status: Mapped[str] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), + nullable=False, + ) outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}") error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) @@ -549,6 +564,15 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( + "WorkflowPause", + primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", + uselist=False, + # require explicit preloading. + lazy="raise", + back_populates="workflow_run", + ) + @property def created_by_account(self): created_by_role = CreatorUserRole(self.created_by_role) @@ -1073,7 +1097,10 @@ class ConversationVariable(Base): DateTime, nullable=False, server_default=func.current_timestamp(), index=True ) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str): @@ -1101,10 +1128,6 @@ class ConversationVariable(Base): _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) -def _naive_utc_datetime(): - return naive_utc_now() - - class WorkflowDraftVariable(Base): """`WorkflowDraftVariable` record variables and outputs generated during debugging workflow or chatflow. 
@@ -1138,14 +1161,14 @@ created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, - default=_naive_utc_datetime, + default=naive_utc_now, server_default=func.current_timestamp(), ) updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, - default=_naive_utc_datetime, + default=naive_utc_now, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), ) @@ -1412,8 +1435,8 @@ class WorkflowDraftVariable(Base): file_id: str | None = None, ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() - variable.created_at = _naive_utc_datetime() - variable.updated_at = _naive_utc_datetime() + variable.created_at = naive_utc_now() + variable.updated_at = naive_utc_now() variable.description = description variable.app_id = app_id variable.node_id = node_id @@ -1518,7 +1541,7 @@ class WorkflowDraftVariableFile(Base): created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, - default=_naive_utc_datetime, + default=naive_utc_now, server_default=func.current_timestamp(), ) @@ -1583,3 +1606,68 @@ def is_system_variable_editable(name: str) -> bool: return name in _EDITABLE_SYSTEM_VARIABLE + + +class WorkflowPause(DefaultFieldsMixin, Base): + """ + WorkflowPause records the paused state and related metadata for a specific workflow run. + + Each `WorkflowRun` can have zero or one associated `WorkflowPause`, depending on its execution status. + If a `WorkflowRun` is in the `PAUSED` state, there must be a corresponding `WorkflowPause` + that has not yet been resumed. + Otherwise, there should be no active (non-resumed) `WorkflowPause` linked to that run. + + This model captures the execution context required to resume workflow processing at a later time. + """ + + __tablename__ = "workflow_pauses" + __table_args__ = ( + # Design Note: + # Instead of adding a `pause_id` field to the `WorkflowRun` model—which would require a migration + # on a potentially large table—we reference `WorkflowRun` from `WorkflowPause` and enforce a unique + # constraint on `workflow_run_id` to guarantee a one-to-one relationship. + UniqueConstraint("workflow_run_id"), + ) + + # `workflow_id` represents the unique identifier of the workflow associated with this pause. + # It corresponds to the `id` field in the `Workflow` model. + # + # Since an application can have multiple versions of a workflow, each with its own unique ID, + # the `app_id` alone is insufficient to determine which workflow version should be loaded + # when resuming a suspended workflow. + workflow_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + + # `workflow_run_id` represents the identifier of the workflow execution, + # corresponding to the `id` field of `WorkflowRun`. + workflow_run_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + + # `resumed_at` records the timestamp when the suspended workflow was resumed. + # It is set to `NULL` if the workflow has not been resumed. + # + # NOTE: Resuming a suspended WorkflowPause does not delete the record immediately. + # It only sets `resumed_at` to a non-null value. + resumed_at: Mapped[datetime | None] = mapped_column( + sa.DateTime, + nullable=True, + ) + + # state_object_key stores the object key referencing the serialized runtime state + # of the `GraphEngine`. This object captures the complete execution context of the + # workflow at the moment it was paused, enabling accurate resumption.
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) + + # Relationship to WorkflowRun + workflow_run: Mapped["WorkflowRun"] = orm.relationship( + foreign_keys=[workflow_run_id], + # require explicit preloading. + lazy="raise", + uselist=False, + primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id", + back_populates="pause", + ) diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index eb6d599224..21fd57cd22 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,6 +38,7 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol +from core.workflow.entities.workflow_pause import WorkflowPauseEntity from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -251,6 +252,116 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def create_workflow_pause( + self, + workflow_run_id: str, + state_owner_user_id: str, + state: str, + ) -> WorkflowPauseEntity: + """ + Create a new workflow pause state. + + Creates a pause state for a workflow run, storing the current execution + state and marking the workflow as paused. This is used when a workflow + needs to be suspended and later resumed. + + Args: + workflow_run_id: Identifier of the workflow run to pause + state_owner_user_id: User ID who owns the pause state for file storage + state: Serialized workflow execution state (JSON string) + + Returns: + WorkflowPauseEntity representing the created pause state + + Raises: + ValueError: If workflow_run_id is invalid or workflow run doesn't exist + RuntimeError: If workflow is already paused or in invalid state + """ + # NOTE: we may get rid of the `state_owner_user_id` in the parameter list. + # However, removing it would require an extra query for the `Workflow` model + # while creating a pause. + ... + + def resume_workflow_pause( + self, + workflow_run_id: str, + pause_entity: WorkflowPauseEntity, + ) -> WorkflowPauseEntity: + """ + Resume a paused workflow. + + Marks a paused workflow as resumed, setting the `resumed_at` field of the WorkflowPauseEntity + and returning the workflow to running status. Returns the pause entity + that was resumed. + + The returned `WorkflowPauseEntity` model has `resumed_at` set. + + NOTE: this method does not delete the corresponding `WorkflowPauseEntity` record and associated state. + It's the caller's responsibility to clear the corresponding state with `delete_workflow_pause`. + + Args: + workflow_run_id: Identifier of the workflow run to resume + pause_entity: The pause entity to resume + + Returns: + WorkflowPauseEntity representing the resumed pause state + + Raises: + ValueError: If workflow_run_id is invalid + RuntimeError: If workflow is not paused or already resumed + """ + ... + + def delete_workflow_pause( + self, + pause_entity: WorkflowPauseEntity, + ) -> None: + """ + Delete a workflow pause state. + + Permanently removes the pause state for a workflow run, including + the stored state file. Used for cleanup operations when a paused + workflow is no longer needed. + + Args: + pause_entity: The pause entity to delete + + Raises: + ValueError: If pause_entity is invalid + RuntimeError: If workflow is not paused + + Note: + This operation is irreversible. The stored workflow state will be + permanently deleted along with the pause record. + """ + ...
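Taken together, the protocol implies a pause/resume/cleanup lifecycle along these lines (a sketch; `some_engine`, the IDs, and `graph_runtime_state` are placeholders):

```python
from sqlalchemy.orm import sessionmaker

from repositories.factory import DifyAPIRepositoryFactory

repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
    sessionmaker(bind=some_engine, expire_on_commit=False)
)

# 1. Pause: snapshot the serialized runtime state; the run flips to PAUSED.
pause = repo.create_workflow_pause(
    workflow_run_id="workflow-run-id",
    state_owner_user_id="workflow-creator-user-id",
    state=graph_runtime_state.dumps(),
)

# 2. Resume: `resumed_at` is set and the run flips back to RUNNING.
resumed = repo.resume_workflow_pause(workflow_run_id="workflow-run-id", pause_entity=pause)

# 3. Cleanup is explicit: drop the pause record and its stored state object.
repo.delete_workflow_pause(resumed)
```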
+ def prune_pauses( + self, + expiration: datetime, + resumption_expiration: datetime, + limit: int | None = None, + ) -> Sequence[str]: + """ + Clean up expired and old pause states. + + Removes pause states that have expired (created before expiration time) + and pause states that were resumed before `resumption_expiration`. + This is used for maintenance and cleanup operations. + + Args: + expiration: Remove pause states created before this time + resumption_expiration: Remove pause states resumed before this time + limit: maximum number of records deleted in one call + + Returns: + a list of ids for pause records that were pruned + + Raises: + ValueError: If parameters are invalid + """ + ... + def get_daily_runs_statistics( self, tenant_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index f08eab0b01..0d52c56138 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -20,19 +20,26 @@ Implementation Notes: """ import logging +import uuid from collections.abc import Sequence from datetime import datetime from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from sqlalchemy import delete, func, select +from sqlalchemy import and_, delete, func, null, or_, select from sqlalchemy.engine import CursorResult -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, selectinload, sessionmaker +from core.workflow.entities.workflow_pause import WorkflowPauseEntity +from core.workflow.enums import WorkflowExecutionStatus +from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold +from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowPause as WorkflowPauseModel from models.workflow import WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.types import ( @@ -45,6 +52,10 @@ from repositories.types import ( logger = logging.getLogger(__name__) +class _WorkflowRunError(Exception): + pass + + class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ SQLAlchemy implementation of APIWorkflowRunRepository. @@ -301,6 +312,278 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted + def create_workflow_pause( + self, + workflow_run_id: str, + state_owner_user_id: str, + state: str, + ) -> WorkflowPauseEntity: + """ + Create a new workflow pause state. + + Creates a pause state for a workflow run, storing the current execution + state and marking the workflow as paused. This is used when a workflow + needs to be suspended and later resumed.
+ + Args: + workflow_run_id: Identifier of the workflow run to pause + state_owner_user_id: User ID who owns the pause state for file storage + state: Serialized workflow execution state (JSON string) + + Returns: + WorkflowPauseEntity representing the created pause state + + Raises: + ValueError: If workflow_run_id is invalid or workflow run doesn't exist + RuntimeError: If workflow is already paused or in invalid state + """ + previous_pause_model_query = select(WorkflowPauseModel).where( + WorkflowPauseModel.workflow_run_id == workflow_run_id + ) + with self._session_maker() as session, session.begin(): + # Get the workflow run + workflow_run = session.get(WorkflowRun, workflow_run_id) + if workflow_run is None: + raise ValueError(f"WorkflowRun not found: {workflow_run_id}") + + # Check if workflow is in RUNNING status + if workflow_run.status != WorkflowExecutionStatus.RUNNING: + raise _WorkflowRunError( + f"Only WorkflowRun with RUNNING status can be paused, " + f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}" + ) + previous_pause = session.scalars(previous_pause_model_query).first() + if previous_pause: + self._delete_pause_model(session, previous_pause) + # we need to flush here to ensure that the old one is actually deleted. + session.flush() + + # Upload the state file + state_obj_key = f"workflow-state-{uuid.uuid4()}.json" + storage.save(state_obj_key, state.encode()) + + # Create the pause record + pause_model = WorkflowPauseModel() + pause_model.id = str(uuidv7()) + pause_model.workflow_id = workflow_run.workflow_id + pause_model.workflow_run_id = workflow_run.id + pause_model.state_object_key = state_obj_key + pause_model.created_at = naive_utc_now() + + # Update workflow run status + workflow_run.status = WorkflowExecutionStatus.PAUSED + + # Save everything in a transaction + session.add(pause_model) + session.add(workflow_run) + + logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) + + return _PrivateWorkflowPauseEntity.from_models(pause_model) + + def get_workflow_pause( + self, + workflow_run_id: str, + ) -> WorkflowPauseEntity | None: + """ + Get an existing workflow pause state. + + Retrieves the pause state for a specific workflow run if it exists. + Used to check if a workflow is paused and to retrieve its saved state. + + Args: + workflow_run_id: Identifier of the workflow run to get pause state for + + Returns: + WorkflowPauseEntity if pause state exists, None otherwise + + Raises: + ValueError: If workflow_run_id is invalid + """ + with self._session_maker() as session: + # Query workflow run with pause and state file + stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) + + if workflow_run is None: + raise ValueError(f"WorkflowRun not found: {workflow_run_id}") + + pause_model = workflow_run.pause + if pause_model is None: + return None + + return _PrivateWorkflowPauseEntity.from_models(pause_model) + + def resume_workflow_pause( + self, + workflow_run_id: str, + pause_entity: WorkflowPauseEntity, + ) -> WorkflowPauseEntity: + """ + Resume a paused workflow. + + Marks a paused workflow as resumed, setting `resumed_at` on the pause and + returning the workflow to running status. Returns the pause entity + that was resumed.
+ + Args: + workflow_run_id: Identifier of the workflow run to resume + pause_entity: The pause entity to resume + + Returns: + WorkflowPauseEntity representing the resumed pause state + + Raises: + ValueError: If workflow_run_id is invalid + RuntimeError: If workflow is not paused or already resumed + """ + with self._session_maker() as session, session.begin(): + # Get the workflow run with pause + stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) + + if workflow_run is None: + raise ValueError(f"WorkflowRun not found: {workflow_run_id}") + + if workflow_run.status != WorkflowExecutionStatus.PAUSED: + raise _WorkflowRunError( + f"WorkflowRun is not in PAUSED status, workflow_run_id={workflow_run_id}, " + f"current_status={workflow_run.status}" + ) + pause_model = workflow_run.pause + if pause_model is None: + raise _WorkflowRunError(f"No pause state found for workflow run: {workflow_run_id}") + + if pause_model.id != pause_entity.id: + raise _WorkflowRunError( + "different id in WorkflowPause and WorkflowPauseEntity, " + f"WorkflowPause.id={pause_model.id}, " + f"WorkflowPauseEntity.id={pause_entity.id}" + ) + + if pause_model.resumed_at is not None: + raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}") + + # Mark as resumed + pause_model.resumed_at = naive_utc_now() + workflow_run.status = WorkflowExecutionStatus.RUNNING + + session.add(pause_model) + session.add(workflow_run) + + logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) + + return _PrivateWorkflowPauseEntity.from_models(pause_model) + + def delete_workflow_pause( + self, + pause_entity: WorkflowPauseEntity, + ) -> None: + """ + Delete a workflow pause state. + + Permanently removes the pause state for a workflow run, including + the stored state file. Used for cleanup operations when a paused + workflow is no longer needed. + + Args: + pause_entity: The pause entity to delete + + Raises: + ValueError: If pause_entity is invalid + _WorkflowRunError: If workflow is not paused + + Note: + This operation is irreversible. The stored workflow state will be + permanently deleted along with the pause record. + """ + with self._session_maker() as session, session.begin(): + # Get the pause model by ID + pause_model = session.get(WorkflowPauseModel, pause_entity.id) + if pause_model is None: + raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}") + self._delete_pause_model(session, pause_model) + + @staticmethod + def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel): + storage.delete(pause_model.state_object_key) + + # Delete the pause record + session.delete(pause_model) + + logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id) + + def prune_pauses( + self, + expiration: datetime, + resumption_expiration: datetime, + limit: int | None = None, + ) -> Sequence[str]: + """ + Clean up expired and old pause states. + + Removes pause states that have expired (created before expiration time) + and pause states that were resumed before `resumption_expiration`. + This is used for maintenance and cleanup operations.
+ + Args: + expiration: Remove pause states created before this time + resumption_expiration: Remove pause states resumed before this time + limit: maximum number of records deleted in one call + + Returns: + a list of ids for pause records that were pruned + + Raises: + ValueError: If parameters are invalid + """ + _limit: int = limit or 1000 + pruned_record_ids: list[str] = [] + cond = or_( + WorkflowPauseModel.created_at < expiration, + and_( + WorkflowPauseModel.resumed_at.is_not(null()), + WorkflowPauseModel.resumed_at < resumption_expiration, + ), + ) + # First, collect pause records to delete with their state files + stmt = select(WorkflowPauseModel).where(cond).limit(_limit) + + with self._session_maker(expire_on_commit=False) as session: + # Get all records to delete + pauses_to_delete = session.scalars(stmt).all() + + # Delete state files from storage + for pause in pauses_to_delete: + with self._session_maker(expire_on_commit=False) as session, session.begin(): + # todo: this issues a separate query for each WorkflowPauseModel record. + # consider batching this lookup. + try: + storage.delete(pause.state_object_key) + logger.info( + "Deleted state object for pause, pause_id=%s, object_key=%s", + pause.id, + pause.state_object_key, + ) + except Exception: + logger.exception( + "Failed to delete state file for pause, pause_id=%s, object_key=%s", + pause.id, + pause.state_object_key, + ) + continue + session.delete(pause) + pruned_record_ids.append(pause.id) + logger.info( + "workflow pause records deleted, id=%s, resumed_at=%s", + pause.id, + pause.resumed_at, + ) + + return pruned_record_ids + def get_daily_runs_statistics( self, tenant_id: str, @@ -510,3 +796,67 @@ GROUP BY ) return cast(list[AverageInteractionStats], response_data)
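For maintenance jobs, `prune_pauses` might be driven like this (the retention windows are illustrative, and `repo` is the repository instance from the earlier sketch):

```python
from datetime import timedelta

from libs.datetime_utils import naive_utc_now

now = naive_utc_now()
# Hypothetical policy: drop any pause older than 7 days, and drop resumed
# pauses once their state has been kept for a day past resumption.
pruned_ids = repo.prune_pauses(
    expiration=now - timedelta(days=7),
    resumption_expiration=now - timedelta(days=1),
    limit=500,
)
print(f"pruned {len(pruned_ids)} pause records")
```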
+
+        Returns:
+            bytes: The serialized workflow state, as stored
+
+        Raises:
+            FileNotFoundError: If the state file cannot be found
+            IOError: If there are issues reading the state file
+        """
+        if self._cached_state is not None:
+            return self._cached_state
+
+        # Load the state from storage
+        state_data = storage.load(self._pause_model.state_object_key)
+        self._cached_state = state_data
+        return state_data
+
+    @property
+    def resumed_at(self) -> datetime | None:
+        return self._pause_model.resumed_at
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 5c8719b499..b903d8df5f 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -1,6 +1,7 @@
 import threading
 from collections.abc import Sequence
 
+from sqlalchemy import Engine
 from sqlalchemy.orm import sessionmaker
 
 import contexts
@@ -14,17 +15,26 @@ from models import (
     WorkflowRun,
     WorkflowRunTriggeredFrom,
 )
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
 from repositories.factory import DifyAPIRepositoryFactory
 
 
 class WorkflowRunService:
-    def __init__(self):
+    _session_factory: sessionmaker
+    _workflow_run_repo: APIWorkflowRunRepository
+
+    def __init__(self, session_factory: Engine | sessionmaker | None = None):
         """Initialize WorkflowRunService with repository dependencies."""
-        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+        if session_factory is None:
+            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+        elif isinstance(session_factory, Engine):
+            session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
+
+        self._session_factory = session_factory
         self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
-            session_maker
+            self._session_factory
         )
-        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
 
     def get_paginate_advanced_chat_workflow_runs(
         self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
diff --git a/api/tests/test_containers_integration_tests/core/__init__.py b/api/tests/test_containers_integration_tests/core/__init__.py
new file mode 100644
index 0000000000..5860ad0399
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/core/__init__.py
@@ -0,0 +1 @@
+# Core integration tests package
diff --git a/api/tests/test_containers_integration_tests/core/app/__init__.py b/api/tests/test_containers_integration_tests/core/app/__init__.py
new file mode 100644
index 0000000000..0822a865b7
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/core/app/__init__.py
@@ -0,0 +1 @@
+# App integration tests package
diff --git a/api/tests/test_containers_integration_tests/core/app/layers/__init__.py b/api/tests/test_containers_integration_tests/core/app/layers/__init__.py
new file mode 100644
index 0000000000..90e5229b1a
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/core/app/layers/__init__.py
@@ -0,0 +1 @@
+# Layers integration tests package
diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
new file mode 100644
index 0000000000..133e600ca0
--- /dev/null
+++
b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -0,0 +1,520 @@ +"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class. + +This test suite covers complete integration scenarios including: +- Real database interactions using containerized PostgreSQL +- Real storage operations using test storage backend +- Complete workflow: event -> state serialization -> database save -> storage save +- Testing with actual WorkflowRunService (not mocked) +- Real Workflow and WorkflowRun instances in database +- Database transactions and rollback behavior +- Actual file upload and retrieval through storage +- Workflow status transitions in database +- Error handling with real database constraints +- Multiple pause events in sequence +- Integration with real ReadOnlyGraphRuntimeState implementations + +These tests use TestContainers to spin up real services for integration testing, +providing more reliable and realistic test scenarios than mocks. +""" + +import json +import uuid +from time import time + +import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session + +from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_events.graph import GraphRunPausedEvent +from core.workflow.runtime.graph_runtime_state import GraphRuntimeState +from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper +from core.workflow.runtime.variable_pool import SystemVariable, VariablePool +from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now +from models import Account +from models import WorkflowPause as WorkflowPauseModel +from models.model import UploadFile +from models.workflow import Workflow, WorkflowRun +from services.file_service import FileService +from services.workflow_run_service import WorkflowRunService + + +class _TestCommandChannelImpl: + """Real implementation of CommandChannel for testing.""" + + def __init__(self): + self._commands: list[GraphEngineCommand] = [] + + def fetch_commands(self) -> list[GraphEngineCommand]: + """Fetch pending commands for this GraphEngine instance.""" + return self._commands.copy() + + def send_command(self, command: GraphEngineCommand) -> None: + """Send a command to be processed by this GraphEngine instance.""" + self._commands.append(command) + + +class TestPauseStatePersistenceLayerTestContainers: + """Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.""" + + @pytest.fixture + def engine(self, db_session_with_containers: Session): + """Get database engine from TestContainers session.""" + bind = db_session_with_containers.get_bind() + assert isinstance(bind, Engine) + return bind + + @pytest.fixture + def file_service(self, engine: Engine): + """Create FileService instance with TestContainers engine.""" + return FileService(engine) + + @pytest.fixture + def workflow_run_service(self, engine: Engine, file_service: FileService): + """Create WorkflowRunService instance with TestContainers engine and FileService.""" + return WorkflowRunService(engine) + + 
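A note on the fixture wiring above: given the `WorkflowRunService.__init__` change earlier in this diff, the service now accepts an `Engine`, a `sessionmaker`, or nothing at all. A minimal sketch of the three call forms follows; the `engine` variable and its DSN are illustrative placeholders, not part of the patch:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from services.workflow_run_service import WorkflowRunService

# Placeholder DSN for illustration only.
engine = create_engine("postgresql+psycopg2://user:pass@localhost/dify")

svc_default = WorkflowRunService()  # falls back to sessionmaker(bind=db.engine)
svc_from_engine = WorkflowRunService(engine)  # an Engine is wrapped in a sessionmaker internally
svc_from_maker = WorkflowRunService(sessionmaker(bind=engine, expire_on_commit=False))
```

Passing the raw `Engine`, as the fixtures here do, keeps test wiring short, while production callers can keep injecting a preconfigured `sessionmaker`.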
@pytest.fixture(autouse=True) + def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): + """Set up test data for each test method using TestContainers.""" + # Create test tenant and account + from models.account import Tenant, TenantAccountJoin, TenantAccountRole + + tenant = Tenant( + name="Test Tenant", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + account = Account( + email="test@example.com", + name="Test User", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant-account join + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + # Set test data + self.test_tenant_id = tenant.id + self.test_user_id = account.id + self.test_app_id = str(uuid.uuid4()) + self.test_workflow_id = str(uuid.uuid4()) + self.test_workflow_run_id = str(uuid.uuid4()) + + # Create test workflow + self.test_workflow = Workflow( + id=self.test_workflow_id, + tenant_id=self.test_tenant_id, + app_id=self.test_app_id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features='{"file_upload": {"enabled": false}}', + created_by=self.test_user_id, + created_at=naive_utc_now(), + ) + + # Create test workflow run + self.test_workflow_run = WorkflowRun( + id=self.test_workflow_run_id, + tenant_id=self.test_tenant_id, + app_id=self.test_app_id, + workflow_id=self.test_workflow_id, + type="workflow", + triggered_from="debugging", + version="draft", + status=WorkflowExecutionStatus.RUNNING, + created_by=self.test_user_id, + created_by_role="account", + created_at=naive_utc_now(), + ) + + # Store session and service instances + self.session = db_session_with_containers + self.file_service = file_service + self.workflow_run_service = workflow_run_service + + # Save test data to database + self.session.add(self.test_workflow) + self.session.add(self.test_workflow_run) + self.session.commit() + + yield + + # Cleanup + self._cleanup_test_data() + + def _cleanup_test_data(self): + """Clean up test data after each test method.""" + try: + # Clean up workflow pauses + self.session.execute(delete(WorkflowPauseModel)) + # Clean up upload files + self.session.execute( + delete(UploadFile).where( + UploadFile.tenant_id == self.test_tenant_id, + ) + ) + # Clean up workflow runs + self.session.execute( + delete(WorkflowRun).where( + WorkflowRun.tenant_id == self.test_tenant_id, + WorkflowRun.app_id == self.test_app_id, + ) + ) + # Clean up workflows + self.session.execute( + delete(Workflow).where( + Workflow.tenant_id == self.test_tenant_id, + Workflow.app_id == self.test_app_id, + ) + ) + self.session.commit() + except Exception as e: + self.session.rollback() + raise e + + def _create_graph_runtime_state( + self, + outputs: dict[str, object] | None = None, + total_tokens: int = 0, + node_run_steps: int = 0, + variables: dict[tuple[str, str], object] | None = None, + workflow_run_id: str | None = None, + ) -> ReadOnlyGraphRuntimeState: + """Create a real GraphRuntimeState for testing.""" + start_at = time() + + execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) + + # Create variable pool + variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + if variables: + 
for (node_id, var_key), value in variables.items(): + variable_pool.add([node_id, var_key], value) + + # Create LLM usage + llm_usage = LLMUsage.empty_usage() + + # Create graph runtime state + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=start_at, + total_tokens=total_tokens, + llm_usage=llm_usage, + outputs=outputs or {}, + node_run_steps=node_run_steps, + ) + + return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state) + + def _create_pause_state_persistence_layer( + self, + workflow_run: WorkflowRun | None = None, + workflow: Workflow | None = None, + state_owner_user_id: str | None = None, + ) -> PauseStatePersistenceLayer: + """Create PauseStatePersistenceLayer with real dependencies.""" + owner_id = state_owner_user_id + if owner_id is None: + if workflow is not None and workflow.created_by: + owner_id = workflow.created_by + elif workflow_run is not None and workflow_run.created_by: + owner_id = workflow_run.created_by + else: + owner_id = getattr(self, "test_user_id", None) + + assert owner_id is not None + owner_id = str(owner_id) + + return PauseStatePersistenceLayer( + session_factory=self.session.get_bind(), + state_owner_user_id=owner_id, + ) + + def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers): + """Test complete pause flow: event -> state serialization -> database save -> storage save.""" + # Arrange + layer = self._create_pause_state_persistence_layer() + + # Create real graph runtime state with test data + test_outputs = {"result": "test_output", "step": "intermediate"} + test_variables = { + ("node1", "var1"): "string_value", + ("node2", "var2"): {"complex": "object"}, + } + graph_runtime_state = self._create_graph_runtime_state( + outputs=test_outputs, + total_tokens=100, + node_run_steps=5, + variables=test_variables, + ) + + command_channel = _TestCommandChannelImpl() + layer.initialize(graph_runtime_state, command_channel) + + # Create pause event + event = GraphRunPausedEvent( + reason=SchedulingPause(message="test pause"), + outputs={"intermediate": "result"}, + ) + + # Act + layer.on_event(event) + + # Assert - Verify pause state was saved to database + self.session.refresh(self.test_workflow_run) + workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id) + assert workflow_run is not None + assert workflow_run.status == WorkflowExecutionStatus.PAUSED + + # Verify pause state exists in database + pause_model = self.session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id) + ).first() + assert pause_model is not None + assert pause_model.workflow_id == self.test_workflow_id + assert pause_model.workflow_run_id == self.test_workflow_run_id + assert pause_model.state_object_key != "" + assert pause_model.resumed_at is None + + storage_content = storage.load(pause_model.state_object_key).decode() + expected_state = json.loads(graph_runtime_state.dumps()) + actual_state = json.loads(storage_content) + + assert actual_state == expected_state + + def test_state_persistence_and_retrieval(self, db_session_with_containers): + """Test that pause state can be persisted and retrieved correctly.""" + # Arrange + layer = self._create_pause_state_persistence_layer() + + # Create complex test data + complex_outputs = { + "nested": {"key": "value", "number": 42}, + "list": [1, 2, 3, {"nested": "item"}], + "boolean": True, + "null_value": None, + } + complex_variables = { + ("node1", "var1"): "string_value", + ("node2", "var2"): {"complex": 
"object"}, + ("node3", "var3"): [1, 2, 3], + } + + graph_runtime_state = self._create_graph_runtime_state( + outputs=complex_outputs, + total_tokens=250, + node_run_steps=10, + variables=complex_variables, + ) + + command_channel = _TestCommandChannelImpl() + layer.initialize(graph_runtime_state, command_channel) + + event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + + # Act - Save pause state + layer.on_event(event) + + # Assert - Retrieve and verify + pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id) + assert pause_entity is not None + assert pause_entity.workflow_execution_id == self.test_workflow_run_id + + state_bytes = pause_entity.get_state() + retrieved_state = json.loads(state_bytes.decode()) + expected_state = json.loads(graph_runtime_state.dumps()) + + assert retrieved_state == expected_state + assert retrieved_state["outputs"] == complex_outputs + assert retrieved_state["total_tokens"] == 250 + assert retrieved_state["node_run_steps"] == 10 + + def test_database_transaction_handling(self, db_session_with_containers): + """Test that database transactions are handled correctly.""" + # Arrange + layer = self._create_pause_state_persistence_layer() + graph_runtime_state = self._create_graph_runtime_state( + outputs={"test": "transaction"}, + total_tokens=50, + ) + + command_channel = _TestCommandChannelImpl() + layer.initialize(graph_runtime_state, command_channel) + + event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + + # Act + layer.on_event(event) + + # Assert - Verify data is committed and accessible in new session + with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session: + workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id) + assert workflow_run is not None + assert workflow_run.status == WorkflowExecutionStatus.PAUSED + + pause_model = new_session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id) + ).first() + assert pause_model is not None + assert pause_model.workflow_run_id == self.test_workflow_run_id + assert pause_model.resumed_at is None + assert pause_model.state_object_key != "" + + def test_file_storage_integration(self, db_session_with_containers): + """Test integration with file storage system.""" + # Arrange + layer = self._create_pause_state_persistence_layer() + + # Create large state data to test storage + large_outputs = {"data": "x" * 10000} # 10KB of data + graph_runtime_state = self._create_graph_runtime_state( + outputs=large_outputs, + total_tokens=1000, + ) + + command_channel = _TestCommandChannelImpl() + layer.initialize(graph_runtime_state, command_channel) + + event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + + # Act + layer.on_event(event) + + # Assert - Verify file was uploaded to storage + self.session.refresh(self.test_workflow_run) + pause_model = self.session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id) + ).first() + assert pause_model is not None + assert pause_model.state_object_key != "" + + # Verify content in storage + storage_content = storage.load(pause_model.state_object_key).decode() + assert storage_content == graph_runtime_state.dumps() + + def test_workflow_with_different_creators(self, db_session_with_containers): + """Test pause state with workflows created by different users.""" + # Arrange - Create workflow with different creator + 
different_user_id = str(uuid.uuid4())
+        different_workflow = Workflow(
+            id=str(uuid.uuid4()),
+            tenant_id=self.test_tenant_id,
+            app_id=self.test_app_id,
+            type="workflow",
+            version="draft",
+            graph='{"nodes": [], "edges": []}',
+            features='{"file_upload": {"enabled": false}}',
+            created_by=different_user_id,
+            created_at=naive_utc_now(),
+        )
+
+        different_workflow_run = WorkflowRun(
+            id=str(uuid.uuid4()),
+            tenant_id=self.test_tenant_id,
+            app_id=self.test_app_id,
+            workflow_id=different_workflow.id,
+            type="workflow",
+            triggered_from="debugging",
+            version="draft",
+            status=WorkflowExecutionStatus.RUNNING,
+            created_by=self.test_user_id,  # Run created by a different user than the workflow
+            created_by_role="account",
+            created_at=naive_utc_now(),
+        )
+
+        self.session.add(different_workflow)
+        self.session.add(different_workflow_run)
+        self.session.commit()
+
+        layer = self._create_pause_state_persistence_layer(
+            workflow_run=different_workflow_run,
+            workflow=different_workflow,
+        )
+
+        graph_runtime_state = self._create_graph_runtime_state(
+            outputs={"creator_test": "different_creator"},
+            workflow_run_id=different_workflow_run.id,
+        )
+
+        command_channel = _TestCommandChannelImpl()
+        layer.initialize(graph_runtime_state, command_channel)
+
+        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+
+        # Act
+        layer.on_event(event)
+
+        # Assert - The layer should use the workflow creator (not the run creator) as state owner
+        self.session.refresh(different_workflow_run)
+        pause_model = self.session.scalars(
+            select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
+        ).first()
+        assert pause_model is not None
+
+        # Verify the pause entity can be retrieved for this run
+        pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
+        assert pause_entity is not None
+
+    def test_layer_ignores_non_pause_events(self, db_session_with_containers):
+        """Test that layer ignores non-pause events."""
+        # Arrange
+        layer = self._create_pause_state_persistence_layer()
+        graph_runtime_state = self._create_graph_runtime_state()
+
+        command_channel = _TestCommandChannelImpl()
+        layer.initialize(graph_runtime_state, command_channel)
+
+        # Import other event types
+        from core.workflow.graph_events.graph import (
+            GraphRunFailedEvent,
+            GraphRunStartedEvent,
+            GraphRunSucceededEvent,
+        )
+
+        # Act - Send non-pause events
+        layer.on_event(GraphRunStartedEvent())
+        layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
+        layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
+
+        # Assert - No pause state should be created
+        self.session.refresh(self.test_workflow_run)
+        assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
+
+        pause_states = (
+            self.session.query(WorkflowPauseModel)
+            .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
+            .all()
+        )
+        assert len(pause_states) == 0
+
+    def test_layer_requires_initialization(self, db_session_with_containers):
+        """Test that layer requires proper initialization before handling events."""
+        # Arrange
+        layer = self._create_pause_state_persistence_layer()
+        # Don't initialize - graph_runtime_state should not be set
+
+        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
+
+        # Act & Assert - Should raise AttributeError
+        with pytest.raises(AttributeError):
+            layer.on_event(event)
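Between the two test modules, it is worth sketching the path these layer tests exercise. The following is a hedged sketch, not code from the patch; `engine`, `owner_id`, `graph_runtime_state`, and `command_channel` stand in for the objects the tests build via fixtures:

```python
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_events.graph import GraphRunPausedEvent

# `engine`, `owner_id`, `graph_runtime_state`, and `command_channel` are assumed
# to exist, wired the same way the fixtures above wire them.
layer = PauseStatePersistenceLayer(session_factory=engine, state_owner_user_id=owner_id)
layer.initialize(graph_runtime_state, command_channel)  # normally done by the graph engine

# On a pause event the layer serializes the runtime state and hands it to the repository:
layer.on_event(GraphRunPausedEvent(reason=SchedulingPause(message="pause")))
# -> repo.create_workflow_pause(workflow_run_id=..., state_owner_user_id=..., state=state.dumps())
#    which moves the WorkflowRun to PAUSED and uploads the state blob to storage.
```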
diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
new file mode 100644
index 0000000000..79da5d4d0e
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
@@ -0,0 +1,948 @@
+"""Comprehensive integration tests for workflow pause functionality.
+
+This test suite covers complete workflow pause functionality including:
+- Real database interactions using containerized PostgreSQL
+- Real storage operations using the test storage backend
+- Complete workflow: create -> pause -> resume -> delete
+- Testing with actual FileService (not mocked)
+- Database transactions and rollback behavior
+- Actual file upload and retrieval through storage
+- Workflow status transitions in the database
+- Error handling with real database constraints
+- Concurrent access scenarios
+- Multi-tenant isolation
+- Prune functionality
+- File storage integration
+
+These tests use TestContainers to spin up real services for integration testing,
+providing more reliable and realistic test scenarios than mocks.
+"""
+
+import json
+import uuid
+from dataclasses import dataclass
+from datetime import timedelta
+
+import pytest
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session, selectinload, sessionmaker
+
+from core.workflow.entities import WorkflowExecution
+from core.workflow.enums import WorkflowExecutionStatus
+from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
+from models import Account
+from models import WorkflowPause as WorkflowPauseModel
+from models.account import Tenant, TenantAccountJoin, TenantAccountRole
+from models.model import UploadFile
+from models.workflow import Workflow, WorkflowRun
+from repositories.sqlalchemy_api_workflow_run_repository import (
+    DifyAPISQLAlchemyWorkflowRunRepository,
+    _WorkflowRunError,
+)
+
+
+@dataclass
+class PauseWorkflowSuccessCase:
+    """Test case for successful pause workflow operations."""
+
+    name: str
+    initial_status: WorkflowExecutionStatus
+    description: str = ""
+
+
+@dataclass
+class PauseWorkflowFailureCase:
+    """Test case for pause workflow failure scenarios."""
+
+    name: str
+    initial_status: WorkflowExecutionStatus
+    description: str = ""
+
+
+@dataclass
+class ResumeWorkflowSuccessCase:
+    """Test case for successful resume workflow operations."""
+
+    name: str
+    initial_status: WorkflowExecutionStatus
+    description: str = ""
+
+
+@dataclass
+class ResumeWorkflowFailureCase:
+    """Test case for resume workflow failure scenarios."""
+
+    name: str
+    initial_status: WorkflowExecutionStatus
+    pause_resumed: bool
+    set_running_status: bool = False
+    description: str = ""
+
+
+@dataclass
+class PrunePausesTestCase:
+    """Test case for prune pauses operations."""
+
+    name: str
+    pause_age: timedelta
+    resume_age: timedelta | None
+    expected_pruned_count: int
+    description: str = ""
+
+
+def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
+    """Create test cases for pause workflow failure scenarios."""
+    return [
+        PauseWorkflowFailureCase(
+            name="pause_already_paused_workflow",
+            initial_status=WorkflowExecutionStatus.PAUSED,
+            description="Should fail to pause an already paused workflow",
+        ),
+        PauseWorkflowFailureCase(
+            name="pause_completed_workflow",
+            initial_status=WorkflowExecutionStatus.SUCCEEDED,
+            description="Should fail to pause a completed workflow",
+        ),
+        PauseWorkflowFailureCase(
+            name="pause_failed_workflow",
+            initial_status=WorkflowExecutionStatus.FAILED,
+            description="Should fail to pause
a failed workflow", + ), + ] + + +def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]: + """Create test cases for successful resume workflow operations.""" + return [ + ResumeWorkflowSuccessCase( + name="resume_paused_workflow", + initial_status=WorkflowExecutionStatus.PAUSED, + description="Should successfully resume a paused workflow", + ), + ] + + +def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]: + """Create test cases for resume workflow failure scenarios.""" + return [ + ResumeWorkflowFailureCase( + name="resume_already_resumed_workflow", + initial_status=WorkflowExecutionStatus.PAUSED, + pause_resumed=True, + description="Should fail to resume an already resumed workflow", + ), + ResumeWorkflowFailureCase( + name="resume_running_workflow", + initial_status=WorkflowExecutionStatus.RUNNING, + pause_resumed=False, + set_running_status=True, + description="Should fail to resume a running workflow", + ), + ] + + +def prune_pauses_test_cases() -> list[PrunePausesTestCase]: + """Create test cases for prune pauses operations.""" + return [ + PrunePausesTestCase( + name="prune_old_active_pauses", + pause_age=timedelta(days=7), + resume_age=None, + expected_pruned_count=1, + description="Should prune old active pauses", + ), + PrunePausesTestCase( + name="prune_old_resumed_pauses", + pause_age=timedelta(hours=12), # Created 12 hours ago (recent) + resume_age=timedelta(days=7), + expected_pruned_count=1, + description="Should prune old resumed pauses", + ), + PrunePausesTestCase( + name="keep_recent_active_pauses", + pause_age=timedelta(hours=1), + resume_age=None, + expected_pruned_count=0, + description="Should keep recent active pauses", + ), + PrunePausesTestCase( + name="keep_recent_resumed_pauses", + pause_age=timedelta(days=1), + resume_age=timedelta(hours=1), + expected_pruned_count=0, + description="Should keep recent resumed pauses", + ), + ] + + +class TestWorkflowPauseIntegration: + """Comprehensive integration tests for workflow pause functionality.""" + + @pytest.fixture(autouse=True) + def setup_test_data(self, db_session_with_containers): + """Set up test data for each test method using TestContainers.""" + # Create test tenant and account + + tenant = Tenant( + name="Test Tenant", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + account = Account( + email="test@example.com", + name="Test User", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant-account join + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + # Set test data + self.test_tenant_id = tenant.id + self.test_user_id = account.id + self.test_app_id = str(uuid.uuid4()) + self.test_workflow_id = str(uuid.uuid4()) + + # Create test workflow + self.test_workflow = Workflow( + id=self.test_workflow_id, + tenant_id=self.test_tenant_id, + app_id=self.test_app_id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features='{"file_upload": {"enabled": false}}', + created_by=self.test_user_id, + created_at=naive_utc_now(), + ) + + # Store session instance + self.session = db_session_with_containers + + # Save test data to database + self.session.add(self.test_workflow) + self.session.commit() + + yield + + # Cleanup + 
self._cleanup_test_data() + + def _cleanup_test_data(self): + """Clean up test data after each test method.""" + # Clean up workflow pauses + self.session.execute(delete(WorkflowPauseModel)) + # Clean up upload files + self.session.execute( + delete(UploadFile).where( + UploadFile.tenant_id == self.test_tenant_id, + ) + ) + # Clean up workflow runs + self.session.execute( + delete(WorkflowRun).where( + WorkflowRun.tenant_id == self.test_tenant_id, + WorkflowRun.app_id == self.test_app_id, + ) + ) + # Clean up workflows + self.session.execute( + delete(Workflow).where( + Workflow.tenant_id == self.test_tenant_id, + Workflow.app_id == self.test_app_id, + ) + ) + self.session.commit() + + def _create_test_workflow_run( + self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING + ) -> WorkflowRun: + """Create a test workflow run with specified status.""" + workflow_run = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=self.test_tenant_id, + app_id=self.test_app_id, + workflow_id=self.test_workflow_id, + type="workflow", + triggered_from="debugging", + version="draft", + status=status, + created_by=self.test_user_id, + created_by_role="account", + created_at=naive_utc_now(), + ) + self.session.add(workflow_run) + self.session.commit() + return workflow_run + + def _create_test_state(self) -> str: + """Create a test state string.""" + return json.dumps( + { + "node_id": "test-node", + "node_type": "llm", + "status": "paused", + "data": {"key": "value"}, + "timestamp": naive_utc_now().isoformat(), + } + ) + + def _get_workflow_run_repository(self): + """Get workflow run repository instance for testing.""" + # Create session factory from the test session + engine = self.session.get_bind() + session_factory = sessionmaker(bind=engine, expire_on_commit=False) + + # Create a test-specific repository that implements the missing save method + class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): + """Test-specific repository that implements the missing save method.""" + + def save(self, execution: WorkflowExecution): + """Implement the missing save method for testing.""" + # For testing purposes, we don't need to implement this method + # as it's not used in the pause functionality tests + pass + + # Create and return repository instance + repository = TestWorkflowRunRepository(session_maker=session_factory) + return repository + + # ==================== Complete Pause Workflow Tests ==================== + + def test_complete_pause_resume_workflow(self): + """Test complete workflow: create -> pause -> resume -> delete.""" + # Arrange + workflow_run = self._create_test_workflow_run() + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + # Act - Create pause state + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + # Assert - Pause state created + assert pause_entity is not None + assert pause_entity.id is not None + assert pause_entity.workflow_execution_id == workflow_run.id + # Convert both to strings for comparison + retrieved_state = pause_entity.get_state() + if isinstance(retrieved_state, bytes): + retrieved_state = retrieved_state.decode() + assert retrieved_state == test_state + + # Verify database state + query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id) + pause_model = self.session.scalars(query).first() + assert pause_model is not None + assert pause_model.resumed_at is 
None + assert pause_model.id == pause_entity.id + + self.session.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.PAUSED + + # Act - Get pause state + retrieved_entity = repository.get_workflow_pause(workflow_run.id) + + # Assert - Pause state retrieved + assert retrieved_entity is not None + assert retrieved_entity.id == pause_entity.id + retrieved_state = retrieved_entity.get_state() + if isinstance(retrieved_state, bytes): + retrieved_state = retrieved_state.decode() + assert retrieved_state == test_state + + # Act - Resume workflow + resumed_entity = repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + + # Assert - Workflow resumed + assert resumed_entity is not None + assert resumed_entity.id == pause_entity.id + assert resumed_entity.resumed_at is not None + + # Verify database state + self.session.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.RUNNING + self.session.refresh(pause_model) + assert pause_model.resumed_at is not None + + # Act - Delete pause state + repository.delete_workflow_pause(pause_entity) + + # Assert - Pause state deleted + with Session(bind=self.session.get_bind()) as session: + deleted_pause = session.get(WorkflowPauseModel, pause_entity.id) + assert deleted_pause is None + + def test_pause_workflow_success(self): + """Test successful pause workflow scenarios.""" + workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING) + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + assert pause_entity is not None + assert pause_entity.workflow_execution_id == workflow_run.id + + retrieved_state = pause_entity.get_state() + if isinstance(retrieved_state, bytes): + retrieved_state = retrieved_state.decode() + assert retrieved_state == test_state + + self.session.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.PAUSED + pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id) + pause_model = self.session.scalars(pause_query).first() + assert pause_model is not None + assert pause_model.id == pause_entity.id + assert pause_model.resumed_at is None + + @pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name) + def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase): + """Test pause workflow failure scenarios.""" + workflow_run = self._create_test_workflow_run(status=test_case.initial_status) + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + with pytest.raises(_WorkflowRunError): + repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + @pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name) + def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase): + """Test successful resume workflow scenarios.""" + workflow_run = self._create_test_workflow_run(status=test_case.initial_status) + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + if workflow_run.status != WorkflowExecutionStatus.RUNNING: + workflow_run.status = WorkflowExecutionStatus.RUNNING + self.session.commit() + + pause_entity = 
repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + self.session.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.PAUSED + + resumed_entity = repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + assert resumed_entity is not None + assert resumed_entity.id == pause_entity.id + assert resumed_entity.resumed_at is not None + + self.session.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.RUNNING + pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id) + pause_model = self.session.scalars(pause_query).first() + assert pause_model is not None + assert pause_model.id == pause_entity.id + assert pause_model.resumed_at is not None + + def test_resume_running_workflow(self): + """Test resume workflow failure scenarios.""" + workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING) + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + self.session.refresh(workflow_run) + workflow_run.status = WorkflowExecutionStatus.RUNNING + self.session.add(workflow_run) + self.session.commit() + + with pytest.raises(_WorkflowRunError): + repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + + def test_resume_resumed_pause(self): + """Test resume workflow failure scenarios.""" + workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING) + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) + pause_model.resumed_at = naive_utc_now() + self.session.add(pause_model) + self.session.commit() + + with pytest.raises(_WorkflowRunError): + repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + + # ==================== Error Scenario Tests ==================== + + def test_pause_nonexistent_workflow_run(self): + """Test pausing a non-existent workflow run.""" + # Arrange + nonexistent_id = str(uuid.uuid4()) + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + # Act & Assert + with pytest.raises(ValueError, match="WorkflowRun not found"): + repository.create_workflow_pause( + workflow_run_id=nonexistent_id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + def test_resume_nonexistent_workflow_run(self): + """Test resuming a non-existent workflow run.""" + # Arrange + workflow_run = self._create_test_workflow_run() + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + nonexistent_id = str(uuid.uuid4()) + + # Act & Assert + with pytest.raises(ValueError, match="WorkflowRun not found"): + repository.resume_workflow_pause( + workflow_run_id=nonexistent_id, + pause_entity=pause_entity, + ) + + # ==================== Prune Functionality Tests 
====================
+
+    @pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
+    def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
+        """Test various prune pauses scenarios."""
+        now = naive_utc_now()
+
+        # Create pause state
+        workflow_run = self._create_test_workflow_run()
+        test_state = self._create_test_state()
+        repository = self._get_workflow_run_repository()
+
+        pause_entity = repository.create_workflow_pause(
+            workflow_run_id=workflow_run.id,
+            state_owner_user_id=self.test_user_id,
+            state=test_state,
+        )
+
+        # Manually adjust timestamps for testing
+        pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
+        pause_model.created_at = now - test_case.pause_age
+
+        if test_case.resume_age is not None:
+            # Resume pause and adjust resume time
+            repository.resume_workflow_pause(
+                workflow_run_id=workflow_run.id,
+                pause_entity=pause_entity,
+            )
+            # Need to refresh to get the updated model
+            self.session.refresh(pause_model)
+            # Manually set the resumed_at to an older time for testing
+            pause_model.resumed_at = now - test_case.resume_age
+            self.session.commit()  # Commit the resumed_at change
+            # Refresh again to ensure the change is persisted
+            self.session.refresh(pause_model)
+
+        self.session.commit()
+
+        # Cutoffs for pruning
+        expiration_time = now - timedelta(days=1, seconds=1)  # Expire pauses older than 1 day (plus 1 second)
+        resumption_time = now - timedelta(
+            days=7, seconds=1
+        )  # Clean up pauses resumed more than 7 days ago (plus 1 second)
+
+        # Ensure the adjusted timestamps are flushed before pruning
+        self.session.refresh(pause_model)
+        self.session.commit()
+
+        # Determine if the pause should be pruned based on timestamps
+        should_be_pruned = False
+        if test_case.resume_age is not None:
+            # If resumed, check if resumed_at is older than resumption_time
+            should_be_pruned = pause_model.resumed_at < resumption_time
+        else:
+            # If not resumed, check if created_at is older than expiration_time
+            should_be_pruned = pause_model.created_at < expiration_time
+
+        # Act - Prune pauses
+        pruned_ids = repository.prune_pauses(
+            expiration=expiration_time,
+            resumption_expiration=resumption_time,
+        )
+
+        # Assert - Check pruning results
+        if should_be_pruned:
+            assert len(pruned_ids) == test_case.expected_pruned_count
+            # The pruned pause must appear in the returned IDs
+            assert pause_entity.id in pruned_ids
+        else:
+            assert len(pruned_ids) == 0
+
+    def test_prune_pauses_with_limit(self):
+        """Test prune pauses with limit parameter."""
+        now = naive_utc_now()
+
+        # Create multiple pause states
+        pause_entities = []
+        repository = self._get_workflow_run_repository()
+
+        for i in range(5):
+            workflow_run = self._create_test_workflow_run()
+            test_state = self._create_test_state()
+
+            pause_entity = repository.create_workflow_pause(
+                workflow_run_id=workflow_run.id,
+                state_owner_user_id=self.test_user_id,
+                state=test_state,
+            )
+            pause_entities.append(pause_entity)
+
+            # Make all pauses old enough to be pruned
+            pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
+            pause_model.created_at = now - timedelta(days=7)
+
+        self.session.commit()
+
+        # Act - Prune with limit
+        expiration_time = now - timedelta(days=1)
+
resumption_time = now - timedelta(days=7) + + pruned_ids = repository.prune_pauses( + expiration=expiration_time, + resumption_expiration=resumption_time, + limit=3, + ) + + # Assert + assert len(pruned_ids) == 3 + + # Verify only 3 were deleted + remaining_count = ( + self.session.query(WorkflowPauseModel) + .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])) + .count() + ) + assert remaining_count == 2 + + # ==================== Multi-tenant Isolation Tests ==================== + + def test_multi_tenant_pause_isolation(self): + """Test that pause states are properly isolated by tenant.""" + # Arrange - Create second tenant + + tenant2 = Tenant( + name="Test Tenant 2", + status="normal", + ) + self.session.add(tenant2) + self.session.commit() + + account2 = Account( + email="test2@example.com", + name="Test User 2", + interface_language="en-US", + status="active", + ) + self.session.add(account2) + self.session.commit() + + tenant2_join = TenantAccountJoin( + tenant_id=tenant2.id, + account_id=account2.id, + role=TenantAccountRole.OWNER, + current=True, + ) + self.session.add(tenant2_join) + self.session.commit() + + # Create workflow for tenant 2 + workflow2 = Workflow( + id=str(uuid.uuid4()), + tenant_id=tenant2.id, + app_id=str(uuid.uuid4()), + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features='{"file_upload": {"enabled": false}}', + created_by=account2.id, + created_at=naive_utc_now(), + ) + self.session.add(workflow2) + self.session.commit() + + # Create workflow runs for both tenants + workflow_run1 = self._create_test_workflow_run() + workflow_run2 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=tenant2.id, + app_id=workflow2.app_id, + workflow_id=workflow2.id, + type="workflow", + triggered_from="debugging", + version="draft", + status=WorkflowExecutionStatus.RUNNING, + created_by=account2.id, + created_by_role="account", + created_at=naive_utc_now(), + ) + self.session.add(workflow_run2) + self.session.commit() + + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + # Act - Create pause for tenant 1 + pause_entity1 = repository.create_workflow_pause( + workflow_run_id=workflow_run1.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + # Try to access pause from tenant 2 using tenant 1's repository + # This should work because we're using the same repository + pause_entity2 = repository.get_workflow_pause(workflow_run2.id) + assert pause_entity2 is None # No pause for tenant 2 yet + + # Create pause for tenant 2 + pause_entity2 = repository.create_workflow_pause( + workflow_run_id=workflow_run2.id, + state_owner_user_id=account2.id, + state=test_state, + ) + + # Assert - Both pauses should exist and be separate + assert pause_entity1 is not None + assert pause_entity2 is not None + assert pause_entity1.id != pause_entity2.id + assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id + + def test_cross_tenant_access_restriction(self): + """Test that cross-tenant access is properly restricted.""" + # This test would require tenant-specific repositories + # For now, we test that pause entities are properly scoped by tenant_id + workflow_run = self._create_test_workflow_run() + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + # Verify pause is properly 
scoped + pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) + assert pause_model.workflow_id == self.test_workflow_id + + # ==================== File Storage Integration Tests ==================== + + def test_file_storage_integration(self): + """Test that state files are properly stored and retrieved.""" + # Arrange + workflow_run = self._create_test_workflow_run() + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + # Act - Create pause state + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + # Assert - Verify file was uploaded to storage + pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) + assert pause_model.state_object_key != "" + + # Verify file content in storage + + file_key = pause_model.state_object_key + storage_content = storage.load(file_key).decode() + assert storage_content == test_state + + # Verify retrieval through entity + retrieved_state = pause_entity.get_state() + if isinstance(retrieved_state, bytes): + retrieved_state = retrieved_state.decode() + assert retrieved_state == test_state + + def test_file_cleanup_on_pause_deletion(self): + """Test that files are properly handled on pause deletion.""" + # Arrange + workflow_run = self._create_test_workflow_run() + test_state = self._create_test_state() + repository = self._get_workflow_run_repository() + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=test_state, + ) + + # Get file info before deletion + pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) + file_key = pause_model.state_object_key + + # Act - Delete pause state + repository.delete_workflow_pause(pause_entity) + + # Assert - Pause record should be deleted + self.session.expire_all() # Clear session to ensure fresh query + deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id) + assert deleted_pause is None + + try: + content = storage.load(file_key).decode() + pytest.fail("File should be deleted from storage after pause deletion") + except FileNotFoundError: + # This is expected - file should be deleted from storage + pass + except Exception as e: + pytest.fail(f"Unexpected error when checking file deletion: {e}") + + def test_large_state_file_handling(self): + """Test handling of large state files.""" + # Arrange - Create a large state (1MB) + large_state = "x" * (1024 * 1024) # 1MB of data + large_state_json = json.dumps({"large_data": large_state}) + + workflow_run = self._create_test_workflow_run() + repository = self._get_workflow_run_repository() + + # Act + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=large_state_json, + ) + + # Assert + assert pause_entity is not None + retrieved_state = pause_entity.get_state() + if isinstance(retrieved_state, bytes): + retrieved_state = retrieved_state.decode() + assert retrieved_state == large_state_json + + # Verify file size in database + pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) + assert pause_model.state_object_key != "" + loaded_state = storage.load(pause_model.state_object_key) + assert loaded_state.decode() == large_state_json + + def test_multiple_pause_resume_cycles(self): + """Test multiple pause/resume cycles on the same workflow run.""" + # Arrange + workflow_run = 
self._create_test_workflow_run() + repository = self._get_workflow_run_repository() + + # Act & Assert - Multiple cycles + for i in range(3): + state = json.dumps({"cycle": i, "data": f"state_{i}"}) + + # Reset workflow run status to RUNNING before each pause (after first cycle) + if i > 0: + self.session.refresh(workflow_run) # Refresh to get latest state from session + workflow_run.status = WorkflowExecutionStatus.RUNNING + self.session.commit() + self.session.refresh(workflow_run) # Refresh again after commit + + # Pause + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=self.test_user_id, + state=state, + ) + assert pause_entity is not None + + # Verify pause + self.session.expire_all() # Clear session to ensure fresh query + self.session.refresh(workflow_run) + + # Use the test session directly to verify the pause + stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id) + workflow_run_with_pause = self.session.scalar(stmt) + pause_model = workflow_run_with_pause.pause + + # Verify pause using test session directly + assert pause_model is not None + assert pause_model.id == pause_entity.id + assert pause_model.state_object_key != "" + + # Load file content using storage directly + file_content = storage.load(pause_model.state_object_key) + if isinstance(file_content, bytes): + file_content = file_content.decode() + assert file_content == state + + # Resume + resumed_entity = repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + assert resumed_entity is not None + assert resumed_entity.resumed_at is not None + + # Verify resume - check that pause is marked as resumed + self.session.expire_all() # Clear session to ensure fresh query + stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id) + resumed_pause_model = self.session.scalar(stmt) + assert resumed_pause_model is not None + assert resumed_pause_model.resumed_at is not None + + # Verify workflow run status + self.session.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.RUNNING diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py new file mode 100644 index 0000000000..3bd967cbc0 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -0,0 +1,278 @@ +import json +from time import time +from unittest.mock import Mock + +import pytest + +from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer +from core.variables.segments import Segment +from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_events.graph import ( + GraphRunFailedEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from repositories.factory import DifyAPIRepositoryFactory + + +class TestDataFactory: + """Factory helpers for constructing graph events used in tests.""" + + @staticmethod + def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent: + return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {}) + + @staticmethod + def create_graph_run_started_event() -> GraphRunStartedEvent: + return 
GraphRunStartedEvent()
+
+    @staticmethod
+    def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
+        return GraphRunSucceededEvent(outputs=outputs or {})
+
+    @staticmethod
+    def create_graph_run_failed_event(
+        error: str = "Test error",
+        exceptions_count: int = 1,
+    ) -> GraphRunFailedEvent:
+        return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
+
+
+class MockSystemVariableReadOnlyView:
+    """Minimal read-only system variable view for testing."""
+
+    def __init__(self, workflow_execution_id: str | None = None) -> None:
+        self._workflow_execution_id = workflow_execution_id
+
+    @property
+    def workflow_execution_id(self) -> str | None:
+        return self._workflow_execution_id
+
+
+class MockReadOnlyVariablePool:
+    """Mock implementation of ReadOnlyVariablePool for testing."""
+
+    def __init__(self, variables: dict[tuple[str, str], object] | None = None):
+        self._variables = variables or {}
+
+    def get(self, node_id: str, variable_key: str) -> Segment | None:
+        value = self._variables.get((node_id, variable_key))
+        if value is None:
+            return None
+        mock_segment = Mock(spec=Segment)
+        mock_segment.value = value
+        return mock_segment
+
+    def get_all_by_node(self, node_id: str) -> dict[str, object]:
+        return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
+
+    def get_by_prefix(self, prefix: str) -> dict[str, object]:
+        return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
+
+
+class MockReadOnlyGraphRuntimeState:
+    """Mock implementation of ReadOnlyGraphRuntimeState for testing."""
+
+    def __init__(
+        self,
+        start_at: float | None = None,
+        total_tokens: int = 0,
+        node_run_steps: int = 0,
+        ready_queue_size: int = 0,
+        exceptions_count: int = 0,
+        outputs: dict[str, object] | None = None,
+        variables: dict[tuple[str, str], object] | None = None,
+        workflow_execution_id: str | None = None,
+    ):
+        self._start_at = start_at or time()
+        self._total_tokens = total_tokens
+        self._node_run_steps = node_run_steps
+        self._ready_queue_size = ready_queue_size
+        self._exceptions_count = exceptions_count
+        self._outputs = outputs or {}
+        self._variable_pool = MockReadOnlyVariablePool(variables)
+        self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
+
+    @property
+    def system_variable(self) -> MockSystemVariableReadOnlyView:
+        return self._system_variable
+
+    @property
+    def variable_pool(self) -> ReadOnlyVariablePool:
+        return self._variable_pool
+
+    @property
+    def start_at(self) -> float:
+        return self._start_at
+
+    @property
+    def total_tokens(self) -> int:
+        return self._total_tokens
+
+    @property
+    def node_run_steps(self) -> int:
+        return self._node_run_steps
+
+    @property
+    def ready_queue_size(self) -> int:
+        return self._ready_queue_size
+
+    @property
+    def exceptions_count(self) -> int:
+        return self._exceptions_count
+
+    @property
+    def outputs(self) -> dict[str, object]:
+        return self._outputs.copy()
+
+    @property
+    def llm_usage(self):
+        mock_usage = Mock()
+        mock_usage.prompt_tokens = 10
+        mock_usage.completion_tokens = 20
+        mock_usage.total_tokens = 30
+        return mock_usage
+
+    def get_output(self, key: str, default: object = None) -> object:
+        return self._outputs.get(key, default)
+
+    def dumps(self) -> str:
+        return json.dumps(
+            {
+                "start_at": self._start_at,
+                "total_tokens": self._total_tokens,
+                "node_run_steps": self._node_run_steps,
+                "ready_queue_size": self._ready_queue_size,
+                "exceptions_count": self._exceptions_count,
+                "outputs": self._outputs,
+                "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
+                "workflow_execution_id": self._system_variable.workflow_execution_id,
+            }
+        )
+
+
+class MockCommandChannel:
+    """Mock implementation of CommandChannel for testing."""
+
+    def __init__(self):
+        self._commands: list[GraphEngineCommand] = []
+
+    def fetch_commands(self) -> list[GraphEngineCommand]:
+        return self._commands.copy()
+
+    def send_command(self, command: GraphEngineCommand) -> None:
+        self._commands.append(command)
+
+
+class TestPauseStatePersistenceLayer:
+    """Unit tests for PauseStatePersistenceLayer."""
+
+    def test_init_with_dependency_injection(self):
+        session_factory = Mock(name="session_factory")
+        state_owner_user_id = "user-123"
+
+        layer = PauseStatePersistenceLayer(
+            session_factory=session_factory,
+            state_owner_user_id=state_owner_user_id,
+        )
+
+        assert layer._session_maker is session_factory
+        assert layer._state_owner_user_id == state_owner_user_id
+        assert not hasattr(layer, "graph_runtime_state")
+        assert not hasattr(layer, "command_channel")
+
+    def test_initialize_sets_dependencies(self):
+        session_factory = Mock(name="session_factory")
+        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
+
+        graph_runtime_state = MockReadOnlyGraphRuntimeState()
+        command_channel = MockCommandChannel()
+
+        layer.initialize(graph_runtime_state, command_channel)
+
+        assert layer.graph_runtime_state is graph_runtime_state
+        assert layer.command_channel is command_channel
+
+    def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
+        session_factory = Mock(name="session_factory")
+        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
+
+        mock_repo = Mock()
+        mock_factory = Mock(return_value=mock_repo)
+        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
+
+        graph_runtime_state = MockReadOnlyGraphRuntimeState(
+            outputs={"result": "test_output"},
+            total_tokens=100,
+            workflow_execution_id="run-123",
+        )
+        command_channel = MockCommandChannel()
+        layer.initialize(graph_runtime_state, command_channel)
+
+        event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
+        expected_state = graph_runtime_state.dumps()
+
+        layer.on_event(event)
+
+        mock_factory.assert_called_once_with(session_factory)
+        mock_repo.create_workflow_pause.assert_called_once_with(
+            workflow_run_id="run-123",
+            state_owner_user_id="owner-123",
+            state=expected_state,
+        )
+
+    def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
+        session_factory = Mock(name="session_factory")
+        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
+
+        mock_repo = Mock()
+        mock_factory = Mock(return_value=mock_repo)
+        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
+
+        graph_runtime_state = MockReadOnlyGraphRuntimeState()
+        command_channel = MockCommandChannel()
+        layer.initialize(graph_runtime_state, command_channel)
+
+        events = [
+            TestDataFactory.create_graph_run_started_event(),
+            TestDataFactory.create_graph_run_succeeded_event(),
+            TestDataFactory.create_graph_run_failed_event(),
+        ]
+
+        for event in events:
+            layer.on_event(event)
+
+        mock_factory.assert_not_called()
+        mock_repo.create_workflow_pause.assert_not_called()
+
+    def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
+        session_factory = Mock(name="session_factory")
+        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
+
+        event = TestDataFactory.create_graph_run_paused_event()
+
+        with pytest.raises(AttributeError):
+            layer.on_event(event)
+
+    def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
+        session_factory = Mock(name="session_factory")
+        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
+
+        mock_repo = Mock()
+        mock_factory = Mock(return_value=mock_repo)
+        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
+
+        graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
+        command_channel = MockCommandChannel()
+        layer.initialize(graph_runtime_state, command_channel)
+
+        event = TestDataFactory.create_graph_run_paused_event()
+
+        with pytest.raises(AssertionError):
+            layer.on_event(event)
+
+        mock_factory.assert_not_called()
+        mock_repo.create_workflow_pause.assert_not_called()
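The tests above drive the layer purely through its `initialize`/`on_event` hooks. For readers unfamiliar with the layer contract, here is a minimal sketch of the same event-filtering pattern using only the hook names and attributes exercised above; `AuditLayer` and its bookkeeping are hypothetical and not part of this PR.

```python
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent


class AuditLayer(GraphEngineLayer):
    """Hypothetical layer: records the run id of every pause it observes."""

    def __init__(self) -> None:
        self.paused_run_ids: list[str] = []

    def on_graph_start(self) -> None:
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        # Same guard style as PauseStatePersistenceLayer: ignore every
        # event except the single type this layer reacts to.
        if not isinstance(event, GraphRunPausedEvent):
            return
        assert self.graph_runtime_state is not None
        run_id = self.graph_runtime_state.system_variable.workflow_execution_id
        if run_id is not None:
            self.paused_run_ids.append(run_id)

    def on_graph_end(self, error: Exception | None) -> None:
        pass
```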
diff --git a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py
new file mode 100644
index 0000000000..ccb2dff85a
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py
@@ -0,0 +1,171 @@
+"""Tests for _PrivateWorkflowPauseEntity implementation."""
+
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+from models.workflow import WorkflowPause as WorkflowPauseModel
+from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
+
+
+class TestPrivateWorkflowPauseEntity:
+    """Test _PrivateWorkflowPauseEntity implementation."""
+
+    def test_entity_initialization(self):
+        """Test entity initialization with required parameters."""
+        # Create mock models
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.id = "pause-123"
+        mock_pause_model.workflow_run_id = "execution-456"
+        mock_pause_model.resumed_at = None
+
+        # Create entity
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        # Verify initialization
+        assert entity._pause_model is mock_pause_model
+        assert entity._cached_state is None
+
+    def test_from_models_classmethod(self):
+        """Test from_models class method."""
+        # Create mock models
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.id = "pause-123"
+        mock_pause_model.workflow_run_id = "execution-456"
+
+        # Create entity using from_models
+        entity = _PrivateWorkflowPauseEntity.from_models(
+            workflow_pause_model=mock_pause_model,
+        )
+
+        # Verify entity creation
+        assert isinstance(entity, _PrivateWorkflowPauseEntity)
+        assert entity._pause_model is mock_pause_model
+
+    def test_id_property(self):
+        """Test id property returns pause model ID."""
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.id = "pause-123"
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        assert entity.id == "pause-123"
+
+    def test_workflow_execution_id_property(self):
+        """Test workflow_execution_id property returns workflow run ID."""
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.workflow_run_id = "execution-456"
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        assert entity.workflow_execution_id == "execution-456"
+
+    def test_resumed_at_property(self):
+        """Test resumed_at property returns pause model resumed_at."""
+        resumed_at = datetime(2023, 12, 25, 15, 30, 45)
+
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.resumed_at = resumed_at
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        assert entity.resumed_at == resumed_at
+
+    def test_resumed_at_property_none(self):
+        """Test resumed_at property returns None when not set."""
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.resumed_at = None
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        assert entity.resumed_at is None
+
+    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
+    def test_get_state_first_call(self, mock_storage):
+        """Test get_state loads from storage on first call."""
+        state_data = b'{"test": "data", "step": 5}'
+        mock_storage.load.return_value = state_data
+
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.state_object_key = "test-state-key"
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        # First call should load from storage
+        result = entity.get_state()
+
+        assert result == state_data
+        mock_storage.load.assert_called_once_with("test-state-key")
+        assert entity._cached_state == state_data
+
+    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
+    def test_get_state_cached_call(self, mock_storage):
+        """Test get_state returns cached data on subsequent calls."""
+        state_data = b'{"test": "data", "step": 5}'
+        mock_storage.load.return_value = state_data
+
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+        mock_pause_model.state_object_key = "test-state-key"
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        # First call
+        result1 = entity.get_state()
+        # Second call should use cache
+        result2 = entity.get_state()
+
+        assert result1 == state_data
+        assert result2 == state_data
+        # Storage should only be called once
+        mock_storage.load.assert_called_once_with("test-state-key")
+
+    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
+    def test_get_state_with_pre_cached_data(self, mock_storage):
+        """Test get_state returns pre-cached data."""
+        state_data = b'{"test": "data", "step": 5}'
+
+        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+
+        entity = _PrivateWorkflowPauseEntity(
+            pause_model=mock_pause_model,
+        )
+
+        # Pre-cache data
+        entity._cached_state = state_data
+
+        # Should return cached data without calling storage
+        result = entity.get_state()
+
+        assert result == state_data
+        mock_storage.load.assert_not_called()
+
+    def test_entity_with_binary_state_data(self):
+        """Test entity with binary state data."""
+        # Test with binary data that's not valid JSON
+        binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
+
+        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
+            mock_storage.load.return_value = binary_data
+
+            mock_pause_model = MagicMock(spec=WorkflowPauseModel)
+
+            entity = _PrivateWorkflowPauseEntity(
+                pause_model=mock_pause_model,
+            )
+
+            result = entity.get_state()
+
+            assert result == binary_data
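The three `get_state` tests pin down a load-once contract. Stripped of the repository specifics, the pattern they assert looks like this sketch; the names are illustrative, and the real entity delegates to the `storage` module patched above.

```python
class CachedStateHolder:
    """Illustrative stand-in for the entity's lazy state loading."""

    def __init__(self, object_key: str, storage) -> None:
        self._object_key = object_key
        self._storage = storage
        self._cached_state: bytes | None = None

    def get_state(self) -> bytes:
        # First call hits storage; every later call returns the cached bytes,
        # which is why the tests assert storage.load is called exactly once.
        if self._cached_state is None:
            self._cached_state = self._storage.load(self._object_key)
        return self._cached_state
```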
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
index d451e7e608..b29baf5a9f 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
@@ -3,6 +3,7 @@
 import time
 from unittest.mock import MagicMock
 
+from core.workflow.entities.pause_reason import SchedulingPause
 from core.workflow.graph import Graph
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
@@ -149,8 +150,8 @@ def test_pause_command():
     assert any(isinstance(e, GraphRunStartedEvent) for e in events)
     pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
     assert len(pause_events) == 1
-    assert pause_events[0].reason == "User requested pause"
+    assert pause_events[0].reason == SchedulingPause(message="User requested pause")
 
     graph_execution = engine.graph_runtime_state.graph_execution
     assert graph_execution.is_paused
-    assert graph_execution.pause_reason == "User requested pause"
+    assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py
new file mode 100644
index 0000000000..7cdb2328f2
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_enums.py
@@ -0,0 +1,32 @@
+"""Tests for workflow pause related enums and constants."""
+
+from core.workflow.enums import (
+    WorkflowExecutionStatus,
+)
+
+
+class TestWorkflowExecutionStatus:
+    """Test WorkflowExecutionStatus enum."""
+
+    def test_is_ended_method(self):
+        """Test is_ended method for different statuses."""
+        # Test ended statuses
+        ended_statuses = [
+            WorkflowExecutionStatus.SUCCEEDED,
+            WorkflowExecutionStatus.FAILED,
+            WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+            WorkflowExecutionStatus.STOPPED,
+        ]
+
+        for status in ended_statuses:
+            assert status.is_ended(), f"{status} should be considered ended"
+
+        # Test non-ended statuses
+        non_ended_statuses = [
+            WorkflowExecutionStatus.SCHEDULED,
+            WorkflowExecutionStatus.RUNNING,
+            WorkflowExecutionStatus.PAUSED,
+        ]
+
+        for status in non_ended_statuses:
+            assert not status.is_ended(), f"{status} should not be considered ended"
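The enum test fixes the boundary between terminal and non-terminal statuses. A plausible shape of the method under test, shown only to make the partition explicit; the member values here are assumptions, and the real enum lives in `core.workflow.enums`.

```python
from enum import StrEnum


class WorkflowExecutionStatusSketch(StrEnum):
    # Values are assumed for illustration; only the is_ended() partition
    # is what the test above actually asserts.
    SCHEDULED = "scheduled"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    PARTIAL_SUCCEEDED = "partial-succeeded"
    STOPPED = "stopped"
    PAUSED = "paused"

    def is_ended(self) -> bool:
        return self in {
            WorkflowExecutionStatusSketch.SUCCEEDED,
            WorkflowExecutionStatusSketch.FAILED,
            WorkflowExecutionStatusSketch.PARTIAL_SUCCEEDED,
            WorkflowExecutionStatusSketch.STOPPED,
        }
```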
diff --git a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py
new file mode 100644
index 0000000000..57bc96fe71
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py
@@ -0,0 +1,202 @@
+from typing import cast
+
+import pytest
+
+from core.file.models import File, FileTransferMethod, FileType
+from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
+
+
+class TestSystemVariableReadOnlyView:
+    """Test cases for SystemVariableReadOnlyView class."""
+
+    def test_read_only_property_access(self):
+        """Test that all properties return correct values from wrapped instance."""
+        # Create test data
+        test_file = File(
+            id="file-123",
+            tenant_id="tenant-123",
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="related-123",
+        )
+
+        datasource_info = {"key": "value", "nested": {"data": 42}}
+
+        # Create SystemVariable with all fields
+        system_var = SystemVariable(
+            user_id="user-123",
+            app_id="app-123",
+            workflow_id="workflow-123",
+            files=[test_file],
+            workflow_execution_id="exec-123",
+            query="test query",
+            conversation_id="conv-123",
+            dialogue_count=5,
+            document_id="doc-123",
+            original_document_id="orig-doc-123",
+            dataset_id="dataset-123",
+            batch="batch-123",
+            datasource_type="type-123",
+            datasource_info=datasource_info,
+            invoke_from="invoke-123",
+        )
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Test all properties
+        assert read_only_view.user_id == "user-123"
+        assert read_only_view.app_id == "app-123"
+        assert read_only_view.workflow_id == "workflow-123"
+        assert read_only_view.workflow_execution_id == "exec-123"
+        assert read_only_view.query == "test query"
+        assert read_only_view.conversation_id == "conv-123"
+        assert read_only_view.dialogue_count == 5
+        assert read_only_view.document_id == "doc-123"
+        assert read_only_view.original_document_id == "orig-doc-123"
+        assert read_only_view.dataset_id == "dataset-123"
+        assert read_only_view.batch == "batch-123"
+        assert read_only_view.datasource_type == "type-123"
+        assert read_only_view.invoke_from == "invoke-123"
+
+    def test_defensive_copying_of_mutable_objects(self):
+        """Test that mutable objects are defensively copied."""
+        # Create test data
+        test_file = File(
+            id="file-123",
+            tenant_id="tenant-123",
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="related-123",
+        )
+
+        datasource_info = {"key": "original_value"}
+
+        # Create SystemVariable
+        system_var = SystemVariable(
+            files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
+        )
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Test files defensive copying
+        files_copy = read_only_view.files
+        assert isinstance(files_copy, tuple)  # Should be immutable tuple
+        assert len(files_copy) == 1
+        assert files_copy[0].id == "file-123"
+
+        # Verify it's a snapshot (can't modify the original through the view)
+        assert isinstance(files_copy, tuple)
+        # tuples expose no mutating methods, so the snapshot cannot be altered
+
+        # Test datasource_info defensive copying
+        datasource_copy = read_only_view.datasource_info
+        assert datasource_copy is not None
+        assert datasource_copy["key"] == "original_value"
+
+        datasource_copy = cast(dict, datasource_copy)
+        with pytest.raises(TypeError):
+            datasource_copy["key"] = "modified value"
+
+        # Verify original is unchanged
+        assert system_var.datasource_info is not None
+        assert system_var.datasource_info["key"] == "original_value"
+        assert read_only_view.datasource_info is not None
+        assert read_only_view.datasource_info["key"] == "original_value"
+
+    def test_always_accesses_latest_data(self):
+        """Test that properties always return the latest data from wrapped instance."""
+        # Create SystemVariable
+        system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Verify initial value
+        assert read_only_view.user_id == "original-user"
+
+        # Modify the wrapped instance
+        system_var.user_id = "modified-user"
+
+        # Verify view returns the new value
+        assert read_only_view.user_id == "modified-user"
+
+    def test_repr_method(self):
+        """Test the __repr__ method."""
+        # Create SystemVariable
+        system_var = SystemVariable(workflow_execution_id="exec-123")
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Test repr
+        repr_str = repr(read_only_view)
+        assert "SystemVariableReadOnlyView" in repr_str
+        assert "system_variable=" in repr_str
+
+    def test_none_value_handling(self):
+        """Test that None values are properly handled."""
+        # Create SystemVariable with all None values except workflow_execution_id
+        system_var = SystemVariable(
+            user_id=None,
+            app_id=None,
+            workflow_id=None,
+            workflow_execution_id="exec-123",
+            query=None,
+            conversation_id=None,
+            dialogue_count=None,
+            document_id=None,
+            original_document_id=None,
+            dataset_id=None,
+            batch=None,
+            datasource_type=None,
+            datasource_info=None,
+            invoke_from=None,
+        )
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Test all None values
+        assert read_only_view.user_id is None
+        assert read_only_view.app_id is None
+        assert read_only_view.workflow_id is None
+        assert read_only_view.query is None
+        assert read_only_view.conversation_id is None
+        assert read_only_view.dialogue_count is None
+        assert read_only_view.document_id is None
+        assert read_only_view.original_document_id is None
+        assert read_only_view.dataset_id is None
+        assert read_only_view.batch is None
+        assert read_only_view.datasource_type is None
+        assert read_only_view.datasource_info is None
+        assert read_only_view.invoke_from is None
+
+        # files should be an empty tuple when no files are set
+        assert read_only_view.files == ()
+
+    def test_empty_files_handling(self):
+        """Test that empty files list is handled correctly."""
+        # Create SystemVariable with empty files
+        system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Test files handling
+        assert read_only_view.files == ()
+        assert isinstance(read_only_view.files, tuple)
+
+    def test_empty_datasource_info_handling(self):
+        """Test that empty datasource_info is handled correctly."""
+        # Create SystemVariable with empty datasource_info
+        system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
+
+        # Create read-only view
+        read_only_view = SystemVariableReadOnlyView(system_var)
+
+        # Test datasource_info handling
+        assert read_only_view.datasource_info == {}
+        # Should be a copy, not the same object
+        assert read_only_view.datasource_info is not system_var.datasource_info
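The defensive-copy behaviour asserted above (a tuple for `files`, a mapping that rejects writes for `datasource_info`) can be realized with a pattern like the following sketch. `MappingProxyType` is one way to satisfy the `pytest.raises(TypeError)` expectation, not necessarily what `SystemVariableReadOnlyView` actually uses.

```python
from types import MappingProxyType


class ReadOnlyViewSketch:
    """Illustrative wrapper showing one way to get the asserted semantics."""

    def __init__(self, files: list[object], datasource_info: dict[str, object]) -> None:
        self._files = files
        self._datasource_info = datasource_info

    @property
    def files(self) -> tuple[object, ...]:
        # A tuple snapshot: callers cannot mutate the underlying list.
        return tuple(self._files)

    @property
    def datasource_info(self) -> MappingProxyType:
        # Item assignment on a mappingproxy raises TypeError, matching the test.
        return MappingProxyType(self._datasource_info)
```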
implementation.""" + + @pytest.fixture + def mock_session(self): + """Create a mock session.""" + return Mock(spec=Session) + + @pytest.fixture + def mock_session_maker(self, mock_session): + """Create a mock sessionmaker.""" + session_maker = Mock(spec=sessionmaker) + + # Create a context manager mock + context_manager = Mock() + context_manager.__enter__ = Mock(return_value=mock_session) + context_manager.__exit__ = Mock(return_value=None) + session_maker.return_value = context_manager + + # Mock session.begin() context manager + begin_context_manager = Mock() + begin_context_manager.__enter__ = Mock(return_value=None) + begin_context_manager.__exit__ = Mock(return_value=None) + mock_session.begin = Mock(return_value=begin_context_manager) + + # Add missing session methods + mock_session.commit = Mock() + mock_session.rollback = Mock() + mock_session.add = Mock() + mock_session.delete = Mock() + mock_session.get = Mock() + mock_session.scalar = Mock() + mock_session.scalars = Mock() + + # Also support expire_on_commit parameter + def make_session(expire_on_commit=None): + cm = Mock() + cm.__enter__ = Mock(return_value=mock_session) + cm.__exit__ = Mock(return_value=None) + return cm + + session_maker.side_effect = make_session + return session_maker + + @pytest.fixture + def repository(self, mock_session_maker): + """Create repository instance with mocked dependencies.""" + + # Create a testable subclass that implements the save method + class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): + def __init__(self, session_maker): + # Initialize without calling parent __init__ to avoid any instantiation issues + self._session_maker = session_maker + + def save(self, execution): + """Mock implementation of save method.""" + return None + + # Create repository instance + repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker) + + return repo + + @pytest.fixture + def sample_workflow_run(self): + """Create a sample WorkflowRun model.""" + workflow_run = Mock(spec=WorkflowRun) + workflow_run.id = "workflow-run-123" + workflow_run.tenant_id = "tenant-123" + workflow_run.app_id = "app-123" + workflow_run.workflow_id = "workflow-123" + workflow_run.status = WorkflowExecutionStatus.RUNNING + return workflow_run + + @pytest.fixture + def sample_workflow_pause(self): + """Create a sample WorkflowPauseModel.""" + pause = Mock(spec=WorkflowPauseModel) + pause.id = "pause-123" + pause.workflow_id = "workflow-123" + pause.workflow_run_id = "workflow-run-123" + pause.state_object_key = "workflow-state-123.json" + pause.resumed_at = None + pause.created_at = datetime.now(UTC) + return pause + + +class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): + """Test create_workflow_pause method.""" + + def test_create_workflow_pause_success( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + mock_session: Mock, + sample_workflow_run: Mock, + ): + """Test successful workflow pause creation.""" + # Arrange + workflow_run_id = "workflow-run-123" + state_owner_user_id = "user-123" + state = '{"test": "state"}' + + mock_session.get.return_value = sample_workflow_run + + with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7: + mock_uuidv7.side_effect = ["pause-123"] + with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: + # Act + result = repository.create_workflow_pause( + workflow_run_id=workflow_run_id, + state_owner_user_id=state_owner_user_id, + 
+
+
+class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
+    """Test create_workflow_pause method."""
+
+    def test_create_workflow_pause_success(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        mock_session: Mock,
+        sample_workflow_run: Mock,
+    ):
+        """Test successful workflow pause creation."""
+        # Arrange
+        workflow_run_id = "workflow-run-123"
+        state_owner_user_id = "user-123"
+        state = '{"test": "state"}'
+
+        mock_session.get.return_value = sample_workflow_run
+
+        with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
+            mock_uuidv7.side_effect = ["pause-123"]
+            with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
+                # Act
+                result = repository.create_workflow_pause(
+                    workflow_run_id=workflow_run_id,
+                    state_owner_user_id=state_owner_user_id,
+                    state=state,
+                )
+
+                # Assert
+                assert isinstance(result, _PrivateWorkflowPauseEntity)
+                assert result.id == "pause-123"
+                assert result.workflow_execution_id == workflow_run_id
+
+                # Verify database interactions
+                mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
+                mock_storage.save.assert_called_once()
+                mock_session.add.assert_called()
+                # When using session.begin() context manager, commit is handled automatically
+                # No explicit commit call is expected
+
+    def test_create_workflow_pause_not_found(
+        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
+    ):
+        """Test workflow pause creation when workflow run not found."""
+        # Arrange
+        mock_session.get.return_value = None
+
+        # Act & Assert
+        with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
+            repository.create_workflow_pause(
+                workflow_run_id="workflow-run-123",
+                state_owner_user_id="user-123",
+                state='{"test": "state"}',
+            )
+
+        mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
+
+    def test_create_workflow_pause_invalid_status(
+        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
+    ):
+        """Test workflow pause creation when workflow not in RUNNING status."""
+        # Arrange
+        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
+        mock_session.get.return_value = sample_workflow_run
+
+        # Act & Assert
+        with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
+            repository.create_workflow_pause(
+                workflow_run_id="workflow-run-123",
+                state_owner_user_id="user-123",
+                state='{"test": "state"}',
+            )
+
+
+class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
+    """Test resume_workflow_pause method."""
+
+    def test_resume_workflow_pause_success(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        mock_session: Mock,
+        sample_workflow_run: Mock,
+        sample_workflow_pause: Mock,
+    ):
+        """Test successful workflow pause resume."""
+        # Arrange
+        workflow_run_id = "workflow-run-123"
+        pause_entity = Mock(spec=WorkflowPauseEntity)
+        pause_entity.id = "pause-123"
+
+        # Setup workflow run and pause
+        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
+        sample_workflow_run.pause = sample_workflow_pause
+        sample_workflow_pause.resumed_at = None
+
+        mock_session.scalar.return_value = sample_workflow_run
+
+        with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
+            mock_now.return_value = datetime.now(UTC)
+
+            # Act
+            result = repository.resume_workflow_pause(
+                workflow_run_id=workflow_run_id,
+                pause_entity=pause_entity,
+            )
+
+            # Assert
+            assert isinstance(result, _PrivateWorkflowPauseEntity)
+            assert result.id == "pause-123"
+
+            # Verify state transitions
+            assert sample_workflow_pause.resumed_at is not None
+            assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
+
+            # Verify database interactions
+            mock_session.add.assert_called()
+            # When using session.begin() context manager, commit is handled automatically
+            # No explicit commit call is expected
+
+    def test_resume_workflow_pause_not_paused(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        mock_session: Mock,
+        sample_workflow_run: Mock,
+    ):
+        """Test resume when workflow is not paused."""
+        # Arrange
+        workflow_run_id = "workflow-run-123"
+        pause_entity = Mock(spec=WorkflowPauseEntity)
+        pause_entity.id = "pause-123"
+
+        sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
+        mock_session.scalar.return_value = sample_workflow_run
+
+        # Act & Assert
+        with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
+            repository.resume_workflow_pause(
+                workflow_run_id=workflow_run_id,
+                pause_entity=pause_entity,
+            )
+
+    def test_resume_workflow_pause_id_mismatch(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        mock_session: Mock,
+        sample_workflow_run: Mock,
+        sample_workflow_pause: Mock,
+    ):
+        """Test resume when pause ID doesn't match."""
+        # Arrange
+        workflow_run_id = "workflow-run-123"
+        pause_entity = Mock(spec=WorkflowPauseEntity)
+        pause_entity.id = "pause-456"  # Different ID
+
+        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
+        sample_workflow_pause.id = "pause-123"
+        sample_workflow_run.pause = sample_workflow_pause
+        mock_session.scalar.return_value = sample_workflow_run
+
+        # Act & Assert
+        with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
+            repository.resume_workflow_pause(
+                workflow_run_id=workflow_run_id,
+                pause_entity=pause_entity,
+            )
+
+
+class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
+    """Test delete_workflow_pause method."""
+
+    def test_delete_workflow_pause_success(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        mock_session: Mock,
+        sample_workflow_pause: Mock,
+    ):
+        """Test successful workflow pause deletion."""
+        # Arrange
+        pause_entity = Mock(spec=WorkflowPauseEntity)
+        pause_entity.id = "pause-123"
+
+        mock_session.get.return_value = sample_workflow_pause
+
+        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
+            # Act
+            repository.delete_workflow_pause(pause_entity=pause_entity)
+
+            # Assert
+            mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
+            mock_session.delete.assert_called_once_with(sample_workflow_pause)
+            # When using session.begin() context manager, commit is handled automatically
+            # No explicit commit call is expected
+
+    def test_delete_workflow_pause_not_found(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        mock_session: Mock,
+    ):
+        """Test delete when pause not found."""
+        # Arrange
+        pause_entity = Mock(spec=WorkflowPauseEntity)
+        pause_entity.id = "pause-123"
+
+        mock_session.get.return_value = None
+
+        # Act & Assert
+        with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
+            repository.delete_workflow_pause(pause_entity=pause_entity)
+
+
+class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
+    """Test _PrivateWorkflowPauseEntity class."""
+
+    def test_from_models(self, sample_workflow_pause: Mock):
+        """Test creating _PrivateWorkflowPauseEntity from models."""
+        # Act
+        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+
+        # Assert
+        assert isinstance(entity, _PrivateWorkflowPauseEntity)
+        assert entity._pause_model == sample_workflow_pause
+
+    def test_properties(self, sample_workflow_pause: Mock):
+        """Test entity properties."""
+        # Arrange
+        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+
+        # Act & Assert
+        assert entity.id == sample_workflow_pause.id
+        assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
+        assert entity.resumed_at == sample_workflow_pause.resumed_at
+
+    def test_get_state(self, sample_workflow_pause: Mock):
+        """Test getting state from storage."""
+        # Arrange
+        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+        expected_state = b'{"test": "state"}'
+
+        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
+            mock_storage.load.return_value = expected_state
+
+            # Act
+            result = entity.get_state()
+
+            # Assert
+            assert result == expected_state
+            mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
+
+    def test_get_state_caching(self, sample_workflow_pause: Mock):
+        """Test state caching in get_state method."""
+        # Arrange
+        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
+        expected_state = b'{"test": "state"}'
+
+        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
+            mock_storage.load.return_value = expected_state
+
+            # Act
+            result1 = entity.get_state()
+            result2 = entity.get_state()  # Should use cache
+
+            # Assert
+            assert result1 == expected_state
+            assert result2 == expected_state
+            mock_storage.load.assert_called_once()  # Only called once due to caching
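Taken together, the repository tests above encode a small state machine: only a RUNNING run can be paused, only a PAUSED run with a matching pause id can be resumed, and deletion removes both the pause row and its state object. A compact restatement of the guards the tests assert, as a sketch (function names and signatures are illustrative, not the repository's API):

```python
from core.workflow.enums import WorkflowExecutionStatus


def can_pause(run_status: WorkflowExecutionStatus) -> bool:
    # create_workflow_pause rejects anything but RUNNING
    return run_status == WorkflowExecutionStatus.RUNNING


def can_resume(run_status: WorkflowExecutionStatus, stored_pause_id: str, entity_id: str) -> bool:
    # resume_workflow_pause requires a PAUSED run whose stored pause row
    # matches the entity being resumed (see the id-mismatch test above)
    return run_status == WorkflowExecutionStatus.PAUSED and stored_pause_id == entity_id
```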
diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py
new file mode 100644
index 0000000000..a062d9444e
--- /dev/null
+++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py
@@ -0,0 +1,200 @@
+"""Unit tests for the WorkflowRunService class.
+
+This test suite covers pause state management operations, including:
+- Retrieving pause state for workflow runs
+- Saving pause state with file uploads
+- Marking paused workflows as resumed
+- Error handling and edge cases
+- Database transaction management
+- Repository-based access patterns
+"""
+
+from datetime import datetime
+from unittest.mock import MagicMock, create_autospec, patch
+
+import pytest
+from sqlalchemy import Engine
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.workflow.enums import WorkflowExecutionStatus
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
+from services.workflow_run_service import (
+    WorkflowRunService,
+)
+
+
+class TestDataFactory:
+    """Factory class for creating test data objects."""
+
+    @staticmethod
+    def create_workflow_run_mock(
+        id: str = "workflow-run-123",
+        tenant_id: str = "tenant-456",
+        app_id: str = "app-789",
+        workflow_id: str = "workflow-101",
+        status: str | WorkflowExecutionStatus = "paused",
+        pause_id: str | None = None,
+        **kwargs,
+    ) -> MagicMock:
+        """Create a mock WorkflowRun object."""
+        mock_run = MagicMock()
+        mock_run.id = id
+        mock_run.tenant_id = tenant_id
+        mock_run.app_id = app_id
+        mock_run.workflow_id = workflow_id
+        mock_run.status = status
+        mock_run.pause_id = pause_id
+
+        for key, value in kwargs.items():
+            setattr(mock_run, key, value)
+
+        return mock_run
+
+    @staticmethod
+    def create_workflow_pause_mock(
+        id: str = "pause-123",
+        tenant_id: str = "tenant-456",
+        app_id: str = "app-789",
+        workflow_id: str = "workflow-101",
+        workflow_execution_id: str = "workflow-execution-123",
+        state_file_id: str = "file-456",
+        resumed_at: datetime | None = None,
+        **kwargs,
+    ) -> MagicMock:
+        """Create a mock WorkflowPauseModel object."""
+        mock_pause = MagicMock()
+        mock_pause.id = id
+        mock_pause.tenant_id = tenant_id
+        mock_pause.app_id = app_id
+        mock_pause.workflow_id = workflow_id
+        mock_pause.workflow_execution_id = workflow_execution_id
+        mock_pause.state_file_id = state_file_id
+        mock_pause.resumed_at = resumed_at
+
+        for key, value in kwargs.items():
+            setattr(mock_pause, key, value)
+
+        return mock_pause
+
+    @staticmethod
+    def create_upload_file_mock(
+        id: str = "file-456",
+        key: str = "upload_files/test/state.json",
+        name: str = "state.json",
+        tenant_id: str = "tenant-456",
+        **kwargs,
+    ) -> MagicMock:
+        """Create a mock UploadFile object."""
+        mock_file = MagicMock()
+        mock_file.id = id
+        mock_file.key = key
+        mock_file.name = name
+        mock_file.tenant_id = tenant_id
+
+        for key, value in kwargs.items():
+            setattr(mock_file, key, value)
+
+        return mock_file
+
+    @staticmethod
+    def create_pause_entity_mock(
+        pause_model: MagicMock | None = None,
+        upload_file: MagicMock | None = None,
+    ) -> _PrivateWorkflowPauseEntity:
+        """Create a mock _PrivateWorkflowPauseEntity object."""
+        if pause_model is None:
+            pause_model = TestDataFactory.create_workflow_pause_mock()
+        if upload_file is None:
+            upload_file = TestDataFactory.create_upload_file_mock()
+
+        return _PrivateWorkflowPauseEntity.from_models(pause_model)
+
+
+class TestWorkflowRunService:
+    """Comprehensive unit tests for WorkflowRunService class."""
+
+    @pytest.fixture
+    def mock_session_factory(self):
+        """Create a mock session factory with proper session management."""
+        mock_session = create_autospec(Session)
+
+        # Create a mock context manager for the session
+        mock_session_cm = MagicMock()
+        mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
+        mock_session_cm.__exit__ = MagicMock(return_value=None)
+
+        # Create a mock context manager for the transaction
+        mock_transaction_cm = MagicMock()
+        mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
+        mock_transaction_cm.__exit__ = MagicMock(return_value=None)
+
+        mock_session.begin = MagicMock(return_value=mock_transaction_cm)
+
+        # Create mock factory that returns the context manager
+        mock_factory = MagicMock(spec=sessionmaker)
+        mock_factory.return_value = mock_session_cm
+
+        return mock_factory, mock_session
+
+    @pytest.fixture
+    def mock_workflow_run_repository(self):
+        """Create a mock APIWorkflowRunRepository."""
+        mock_repo = create_autospec(APIWorkflowRunRepository)
+        return mock_repo
+
+    @pytest.fixture
+    def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
+        """Create WorkflowRunService instance with mocked dependencies."""
+        session_factory, _ = mock_session_factory
+
+        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
+            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
+            service = WorkflowRunService(session_factory)
+            return service
+
+    @pytest.fixture
+    def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
+        """Create WorkflowRunService instance with Engine input."""
+        mock_engine = create_autospec(Engine)
+        session_factory, _ = mock_session_factory
+
+        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
+            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
+            service = WorkflowRunService(mock_engine)
+            return service
+
+    # ==================== Initialization Tests ====================
+
+    def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
+        """Test WorkflowRunService initialization with session_factory."""
+        session_factory, _ = mock_session_factory
+
+        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
+            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
+            service = WorkflowRunService(session_factory)
+
+        assert service._session_factory == session_factory
+        mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
+
+    def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
+        """Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
+        mock_engine = create_autospec(Engine)
+        session_factory, _ = mock_session_factory
+
+        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
+            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
+            with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
+                service = WorkflowRunService(mock_engine)
+
+        mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
+        assert service._session_factory == session_factory
+        mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
+
+    def test_init_with_default_dependencies(self, mock_session_factory):
+        """Test WorkflowRunService initialization with default dependencies."""
+        session_factory, _ = mock_session_factory
+
+        service = WorkflowRunService(session_factory)
+
+        assert service._session_factory == session_factory
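test_init_with_engine pins the Engine-to-sessionmaker conversion down to the exact keyword arguments. The normalization idiom it implies, shown standalone below, mirrors PauseStatePersistenceLayer.__init__ earlier in this PR; whether WorkflowRunService shares a helper like this is an assumption.

```python
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker


def normalize_session_factory(factory: Engine | sessionmaker) -> sessionmaker:
    # An Engine is wrapped exactly as the test asserts:
    # sessionmaker(bind=engine, expire_on_commit=False)
    if isinstance(factory, Engine):
        return sessionmaker(bind=factory, expire_on_commit=False)
    return factory
```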