mirror of https://github.com/langgenius/dify.git
feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
fd7c4e8a6d
commit
a1c0bd7a1c
@@ -1,6 +1,6 @@
import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, cast

from sqlalchemy import select

@@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository

@@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        app: App,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        graph_engine_layers: Sequence[GraphEngineLayer] = (),
    ):
        super().__init__(
            queue_manager=queue_manager,
            variable_loader=variable_loader,
            app_id=application_generate_entity.app_config.app_id,
            graph_engine_layers=graph_engine_layers,
        )
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation

@@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        )

        workflow_entry.graph_engine.layer(persistence_layer)
        for layer in self._graph_engine_layers:
            workflow_entry.graph_engine.layer(layer)

        generator = workflow_entry.run()

@@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
        )

        workflow_entry.graph_engine.layer(persistence_layer)
        for layer in self._graph_engine_layers:
            workflow_entry.graph_engine.layer(layer)

        generator = workflow_entry.run()

@@ -1,5 +1,5 @@
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom

@@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
)
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
    GraphEngineEvent,
    GraphRunFailedEvent,

@@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
        queue_manager: AppQueueManager,
        variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
        app_id: str,
        graph_engine_layers: Sequence[GraphEngineLayer] = (),
    ):
        self._queue_manager = queue_manager
        self._variable_loader = variable_loader
        self._app_id = app_id
        self._graph_engine_layers = graph_engine_layers

    def _init_graph(
        self,

@@ -0,0 +1,71 @@
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory


class PauseStatePersistenceLayer(GraphEngineLayer):
    def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
        """Create a PauseStatePersistenceLayer.

        The `state_owner_user_id` is used when creating the state file for a pause.
        It should generally be the ID of the workflow's creator.
        """
        if isinstance(session_factory, Engine):
            session_factory = sessionmaker(session_factory)
        self._session_maker = session_factory
        self._state_owner_user_id = state_owner_user_id

    def _get_repo(self) -> APIWorkflowRunRepository:
        return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)

    def on_graph_start(self) -> None:
        """
        Called when graph execution starts.

        This is called after the engine has been initialized but before any nodes
        are executed. Layers can use this to set up resources or log start information.
        """
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Called for every event emitted by the engine.

        This method receives all events generated during graph execution, including:
        - Graph lifecycle events (start, success, failure)
        - Node execution events (start, success, failure, retry)
        - Stream events for response nodes
        - Container events (iteration, loop)

        Args:
            event: The event emitted by the engine
        """
        if not isinstance(event, GraphRunPausedEvent):
            return

        assert self.graph_runtime_state is not None
        workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
        assert workflow_run_id is not None
        repo = self._get_repo()
        repo.create_workflow_pause(
            workflow_run_id=workflow_run_id,
            state_owner_user_id=self._state_owner_user_id,
            state=self.graph_runtime_state.dumps(),
        )

    def on_graph_end(self, error: Exception | None) -> None:
        """
        Called when graph execution ends.

        This is called after all nodes have been executed or when execution is
        aborted. Layers can use this to clean up resources or log final state.

        Args:
            error: The exception that caused execution to fail, or None if successful
        """
        pass
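For orientation, a minimal sketch (not part of the diff) of how this layer is attached, mirroring the `graph_engine.layer(...)` calls in the runner hunks above; `workflow_entry`, `session_factory`, and `owner_id` are assumed names:

# Sketch: attach the pause-persistence layer like any other GraphEngineLayer.
pause_layer = PauseStatePersistenceLayer(
    session_factory=session_factory,  # an Engine or sessionmaker (assumed)
    state_owner_user_id=owner_id,     # typically the workflow creator's ID
)
workflow_entry.graph_engine.layer(pause_layer)
# When the engine emits GraphRunPausedEvent, on_event() above serializes the
# runtime state and persists it via create_workflow_pause().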
@@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity

__all__ = [
    "AgentNodeStrategyInit",
@@ -12,4 +13,5 @@ __all__ = [
    "VariablePool",
    "WorkflowExecution",
    "WorkflowNodeExecution",
    "WorkflowPauseEntity",
]

@@ -0,0 +1,49 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias

from pydantic import BaseModel, Discriminator, Tag


class _PauseReasonType(StrEnum):
    HUMAN_INPUT_REQUIRED = auto()
    SCHEDULED_PAUSE = auto()


class _PauseReasonBase(BaseModel):
    TYPE: ClassVar[_PauseReasonType]


class HumanInputRequired(_PauseReasonBase):
    TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED


class SchedulingPause(_PauseReasonBase):
    TYPE = _PauseReasonType.SCHEDULED_PAUSE

    message: str


def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
    if isinstance(v, _PauseReasonBase):
        return v.TYPE
    elif isinstance(v, dict):
        reason_type_str = v.get("TYPE")
        if reason_type_str is None:
            return None
        try:
            reason_type = _PauseReasonType(reason_type_str)
        except ValueError:
            return None
        return reason_type
    else:
        # return None if the discriminator value isn't found
        return None


PauseReason: TypeAlias = Annotated[
    (
        Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
        | Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
    ),
    Discriminator(_get_pause_reason_discriminator),
]
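A quick illustration (not part of the commit) of how the tagged union resolves; it assumes pydantic v2's `TypeAdapter` and models that tolerate the extra `TYPE` key in dict input:

from pydantic import TypeAdapter

adapter = TypeAdapter(PauseReason)
# Dicts are discriminated via their "TYPE" tag (see the callback above).
reason = adapter.validate_python({"TYPE": "scheduled_pause", "message": "nightly window"})
assert isinstance(reason, SchedulingPause)
# Model instances are discriminated via their class-level TYPE.
assert isinstance(adapter.validate_python(HumanInputRequired()), HumanInputRequired)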
@@ -0,0 +1,61 @@
"""
Domain entities for workflow pause management.

This module contains the domain model for workflow pause, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""

from abc import ABC, abstractmethod
from datetime import datetime


class WorkflowPauseEntity(ABC):
    """
    Abstract base class for workflow pause entities.

    This domain model represents a paused workflow execution state,
    without implementation details like tenant_id, app_id, etc.
    It provides the interface for managing workflow pause/resume operations
    and state persistence through file storage.

    A `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
    it will generate multiple `WorkflowPauseEntity` records.
    """

    @property
    @abstractmethod
    def id(self) -> str:
        """The identifier of the current WorkflowPauseEntity."""
        pass

    @property
    @abstractmethod
    def workflow_execution_id(self) -> str:
        """The identifier of the workflow execution record the pause is associated with.

        Corresponds to `WorkflowExecution.id`.
        """

    @abstractmethod
    def get_state(self) -> bytes:
        """
        Retrieve the serialized workflow state from storage.

        This method should load and return the workflow execution state
        that was saved when the workflow was paused. The state contains
        all necessary information to resume the workflow execution.

        Returns:
            bytes: The serialized workflow state containing
                execution context, variable values, node states, etc.
        """
        ...

    @property
    @abstractmethod
    def resumed_at(self) -> datetime | None:
        """`resumed_at` returns the resumption time of the current pause, or `None` if
        the pause has not been resumed yet.
        """
        pass
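To make the contract concrete, here is a minimal in-memory implementation (illustrative only; the commit's real implementation is the storage-backed `_PrivateWorkflowPauseEntity` in the repository module further below):

# Illustrative in-memory implementation of the ABC above.
class InMemoryWorkflowPause(WorkflowPauseEntity):
    def __init__(self, id_: str, workflow_execution_id: str, state: bytes):
        self._id = id_
        self._workflow_execution_id = workflow_execution_id
        self._state = state
        self._resumed_at: datetime | None = None

    @property
    def id(self) -> str:
        return self._id

    @property
    def workflow_execution_id(self) -> str:
        return self._workflow_execution_id

    def get_state(self) -> bytes:
        return self._state

    @property
    def resumed_at(self) -> datetime | None:
        return self._resumed_at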
@@ -92,13 +92,111 @@ class WorkflowType(StrEnum):


class WorkflowExecutionStatus(StrEnum):
    # State diagram for the workflow status:
    # (@) means start, (*) means end
    #
    # (An ASCII-art state diagram appears here in the source; its alignment did
    # not survive extraction. The Mermaid diagram below encodes the same states
    # and transitions.)
    #
    # Mermaid diagram:
    #
    # ---
    # title: State diagram for Workflow run state
    # ---
    # stateDiagram-v2
    #     scheduled: Scheduled
    #     running: Running
    #     succeeded: Succeeded
    #     failed: Failed
    #     partial_succeeded: Partial Succeeded
    #     paused: Paused
    #     stopped: Stopped
    #
    #     [*] --> scheduled
    #     scheduled --> running: Start Execution
    #     running --> paused: Human input required
    #     paused --> running: Human input added
    #     paused --> stopped: User stops execution
    #     running --> succeeded: Execution finishes without any error
    #     running --> failed: Execution finishes with errors
    #     running --> stopped: User stops execution
    #     running --> partial_succeeded: Some errors occurred and were handled during execution
    #
    #     scheduled --> stopped: User stops execution
    #
    #     succeeded --> [*]
    #     failed --> [*]
    #     partial_succeeded --> [*]
    #     stopped --> [*]

    # `SCHEDULED` means that the workflow is scheduled to run, but has not
    # started running yet (maybe due to worker saturation).
    #
    # This enum value is currently unused.
    SCHEDULED = "scheduled"

    # `RUNNING` means the workflow is executing.
    RUNNING = "running"

    # `SUCCEEDED` means the execution of the workflow succeeded without any error.
    SUCCEEDED = "succeeded"

    # `FAILED` means the execution of the workflow failed with unhandled errors.
    FAILED = "failed"

    # `STOPPED` means the execution of the workflow was stopped, either manually
    # by the user, or automatically by the Dify application (e.g. the moderation
    # mechanism).
    STOPPED = "stopped"

    # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
    # execution, but they were successfully handled (e.g., by using an error
    # strategy such as "fail branch" or "default value").
    PARTIAL_SUCCEEDED = "partial-succeeded"

    # `PAUSED` indicates that the workflow execution is temporarily paused
    # (e.g., awaiting human input) and is expected to resume later.
    PAUSED = "paused"

    def is_ended(self) -> bool:
        return self in _END_STATE


_END_STATE = frozenset(
    [
        WorkflowExecutionStatus.SUCCEEDED,
        WorkflowExecutionStatus.FAILED,
        WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        WorkflowExecutionStatus.STOPPED,
    ]
)
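A quick illustration of the terminal-state helper (not part of the commit): `PAUSED` is deliberately not an end state, since a paused run is expected to resume.

# Illustration: PAUSED and RUNNING are not terminal; SUCCEEDED and STOPPED are.
assert not WorkflowExecutionStatus.PAUSED.is_ended()
assert not WorkflowExecutionStatus.RUNNING.is_ended()
assert WorkflowExecutionStatus.SUCCEEDED.is_ended()
assert WorkflowExecutionStatus.STOPPED.is_ended()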


class WorkflowNodeExecutionMetadataKey(StrEnum):
    """

@@ -3,6 +3,8 @@ from typing import final

from typing_extensions import override

from core.workflow.entities.pause_reason import SchedulingPause

from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler

@@ -25,4 +27,7 @@ class PauseCommandHandler(CommandHandler):
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        assert isinstance(command, PauseCommand)
        logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
        execution.pause(command.reason)
        # Convert the string reason into a PauseReason
        reason = command.reason
        pause_reason = SchedulingPause(message=reason)
        execution.pause(pause_reason)
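For context, a sketch of the command this handler consumes (the absolute import path is assumed from the relative imports above):

# Sketch: a pause command carries a plain string reason...
from core.workflow.graph_engine.entities.commands import PauseCommand  # path assumed

command = PauseCommand(reason="awaiting nightly batch window")
# ...which handle() above wraps into SchedulingPause(message=...) before
# calling execution.pause(pause_reason).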
@@ -8,6 +8,7 @@ from typing import Literal

from pydantic import BaseModel, Field

from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState

from .node_execution import NodeExecution

@@ -41,7 +42,7 @@ class GraphExecutionState(BaseModel):
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    paused: bool = Field(default=False)
    pause_reason: str | None = Field(default=None)
    pause_reason: PauseReason | None = Field(default=None)
    error: GraphExecutionErrorState | None = Field(default=None)
    exceptions_count: int = Field(default=0)
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])

@@ -106,7 +107,7 @@ class GraphExecution:
    completed: bool = False
    aborted: bool = False
    paused: bool = False
    pause_reason: str | None = None
    pause_reason: PauseReason | None = None
    error: Exception | None = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0

@@ -130,7 +131,7 @@ class GraphExecution:
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def pause(self, reason: str | None = None) -> None:
    def pause(self, reason: PauseReason) -> None:
        """Pause the graph execution without marking it complete."""
        if self.completed:
            raise RuntimeError("Cannot pause execution that has completed")

@@ -36,4 +36,4 @@ class PauseCommand(GraphEngineCommand):
    """Command to pause a running workflow execution."""

    command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
    reason: str | None = Field(default=None, description="Optional reason for pause")
    reason: str = Field(default="unknown reason", description="reason for pause")

@@ -210,7 +210,7 @@ class EventHandler:
    def _(self, event: NodeRunPauseRequestedEvent) -> None:
        """Handle pause requests emitted by nodes."""

        pause_reason = event.reason or "Awaiting human input"
        pause_reason = event.reason
        self._graph_execution.pause(pause_reason)
        self._state_manager.finish_execution(event.node_id)
        if event.node_id in self._graph.nodes:

@@ -247,8 +247,11 @@ class GraphEngine:

        # Handle completion
        if self._graph_execution.is_paused:
            pause_reason = self._graph_execution.pause_reason
            assert pause_reason is not None, "pause_reason should not be None when execution is paused."
            # Ensure we have a valid PauseReason for the event
            paused_event = GraphRunPausedEvent(
                reason=self._graph_execution.pause_reason,
                reason=pause_reason,
                outputs=self._graph_runtime_state.outputs,
            )
            self._event_manager.notify_layers(paused_event)

@@ -216,7 +216,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
    def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.PAUSED
        execution.error_message = event.reason or "Workflow execution paused"
        execution.outputs = event.outputs
        self._populate_completion_statistics(execution, update_finished=False)

@@ -296,7 +295,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.PAUSED,
            error=event.reason,
            error="",
            update_outputs=False,
        )

@@ -1,5 +1,6 @@
from pydantic import Field

from core.workflow.entities.pause_reason import PauseReason
from core.workflow.graph_events import BaseGraphEvent

@@ -44,7 +45,8 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
    """Event emitted when a graph run is paused by user command."""

    reason: str | None = Field(default=None, description="reason for pause")
    # reason: str | None = Field(default=None, description="reason for pause")
    reason: PauseReason = Field(..., description="reason for pause")
    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Outputs available to the client while the run is paused.",

@@ -5,6 +5,7 @@ from pydantic import Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason

from .base import GraphNodeEventBase

@@ -54,4 +55,4 @@ class NodeRunRetryEvent(NodeRunStartedEvent):


class NodeRunPauseRequestedEvent(GraphNodeEventBase):
    reason: str | None = Field(default=None, description="Optional pause reason")
    reason: PauseReason = Field(..., description="pause reason")

@@ -5,6 +5,7 @@ from pydantic import Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult

from .base import NodeEventBase

@@ -43,4 +44,4 @@ class StreamCompletedEvent(NodeEventBase):


class PauseRequestedEvent(NodeEventBase):
    reason: str | None = Field(default=None, description="Optional pause reason")
    reason: PauseReason = Field(..., description="pause reason")

@@ -1,6 +1,7 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig

@@ -64,7 +65,7 @@ class HumanInputNode(Node):
        return self._pause_generator()

    def _pause_generator(self):
        yield PauseRequestedEvent(reason=self._node_data.pause_reason)
        yield PauseRequestedEvent(reason=HumanInputRequired())

    def _is_completion_ready(self) -> bool:
        """Determine whether all required inputs are satisfied."""
@@ -3,6 +3,7 @@ from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView


class ReadOnlyVariablePool(Protocol):

@@ -30,6 +31,9 @@ class ReadOnlyGraphRuntimeState(Protocol):
    All methods return defensive copies to ensure immutability.
    """

    @property
    def system_variable(self) -> SystemVariableReadOnlyView: ...

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        """Get read-only access to the variable pool."""

@@ -6,6 +6,7 @@ from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView

from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool

@@ -42,6 +43,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
        self._state = state
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def system_variable(self) -> SystemVariableReadOnlyView:
        return self._state.variable_pool.system_variables.as_view()

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        return self._variable_pool_wrapper

@@ -1,4 +1,5 @@
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any

from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator

@@ -108,3 +109,102 @@ class SystemVariable(BaseModel):
        if self.invoke_from is not None:
            d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
        return d

    def as_view(self) -> "SystemVariableReadOnlyView":
        return SystemVariableReadOnlyView(self)


class SystemVariableReadOnlyView:
    """
    A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol.

    This class wraps a SystemVariable instance and provides read-only access to all its fields.
    It always reads the latest data from the wrapped instance and prevents any write operations.
    """

    def __init__(self, system_variable: SystemVariable) -> None:
        """
        Initialize the read-only view with a SystemVariable instance.

        Args:
            system_variable: The SystemVariable instance to wrap
        """
        self._system_variable = system_variable

    @property
    def user_id(self) -> str | None:
        return self._system_variable.user_id

    @property
    def app_id(self) -> str | None:
        return self._system_variable.app_id

    @property
    def workflow_id(self) -> str | None:
        return self._system_variable.workflow_id

    @property
    def workflow_execution_id(self) -> str | None:
        return self._system_variable.workflow_execution_id

    @property
    def query(self) -> str | None:
        return self._system_variable.query

    @property
    def conversation_id(self) -> str | None:
        return self._system_variable.conversation_id

    @property
    def dialogue_count(self) -> int | None:
        return self._system_variable.dialogue_count

    @property
    def document_id(self) -> str | None:
        return self._system_variable.document_id

    @property
    def original_document_id(self) -> str | None:
        return self._system_variable.original_document_id

    @property
    def dataset_id(self) -> str | None:
        return self._system_variable.dataset_id

    @property
    def batch(self) -> str | None:
        return self._system_variable.batch

    @property
    def datasource_type(self) -> str | None:
        return self._system_variable.datasource_type

    @property
    def invoke_from(self) -> str | None:
        return self._system_variable.invoke_from

    @property
    def files(self) -> Sequence[File]:
        """
        Get a copy of the files from the wrapped SystemVariable.

        Returns:
            A defensive copy of the files sequence to prevent modification
        """
        return tuple(self._system_variable.files)  # Convert to immutable tuple

    @property
    def datasource_info(self) -> Mapping[str, Any] | None:
        """
        Get a copy of the datasource info from the wrapped SystemVariable.

        Returns:
            A view of the datasource info mapping to prevent modification
        """
        if self._system_variable.datasource_info is None:
            return None
        return MappingProxyType(self._system_variable.datasource_info)

    def __repr__(self) -> str:
        """Return a string representation of the read-only view."""
        return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})"
@@ -85,7 +85,7 @@ class Storage:
            case _:
                raise ValueError(f"unsupported storage type {storage_type}")

    def save(self, filename, data):
    def save(self, filename: str, data: bytes):
        self.storage_runner.save(filename, data)

    @overload

@@ -8,7 +8,7 @@ class BaseStorage(ABC):
    """Interface for file storage."""

    @abstractmethod
    def save(self, filename, data):
    def save(self, filename: str, data: bytes):
        raise NotImplementedError

    @abstractmethod

@@ -0,0 +1,41 @@
"""add WorkflowPause model

Revision ID: 03f8dcbc611e
Revises: ae662b25d9bc
Create Date: 2025-10-22 16:11:31.805407

"""

from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "03f8dcbc611e"
down_revision = "ae662b25d9bc"
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "workflow_pauses",
        sa.Column("workflow_id", models.types.StringUUID(), nullable=False),
        sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
        sa.Column("resumed_at", sa.DateTime(), nullable=True),
        sa.Column("state_object_key", sa.String(length=255), nullable=False),
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("workflow_pauses_pkey")),
        sa.UniqueConstraint("workflow_run_id", name=op.f("workflow_pauses_workflow_run_id_key")),
    )
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("workflow_pauses")
    # ### end Alembic commands ###

@@ -88,6 +88,7 @@ from .workflow import (
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionOffload,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowPause,
    WorkflowRun,
    WorkflowType,
)

@@ -177,6 +178,7 @@ __all__ = [
    "WorkflowNodeExecutionModel",
    "WorkflowNodeExecutionOffload",
    "WorkflowNodeExecutionTriggeredFrom",
    "WorkflowPause",
    "WorkflowRun",
    "WorkflowRunTriggeredFrom",
    "WorkflowToolProvider",

@@ -1,6 +1,12 @@
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
from datetime import datetime

from sqlalchemy import DateTime, func, text
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.engine import metadata
from models.types import StringUUID


class Base(DeclarativeBase):

@@ -13,3 +19,34 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
    """

    metadata = metadata


class DefaultFieldsMixin:
    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        # NOTE: The default and server_default serve as fallback mechanisms.
        # The application can generate the `id` before saving to optimize
        # the insertion process (especially for interdependent models)
        # and reduce database roundtrips.
        default=uuidv7,
        server_default=text("uuidv7()"),
    )

    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
    )

    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(id={self.id})>"
@@ -13,8 +13,11 @@ from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeType
from core.workflow.constants import (
    CONVERSATION_VARIABLE_NODE_ID,
    SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now

@@ -35,7 +38,7 @@ from factories import variable_factory
from libs import helper

from .account import Account
from .base import Base
from .base import Base, DefaultFieldsMixin
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
from .types import EnumText, StringUUID

@@ -247,7 +250,9 @@ class Workflow(Base):
        return node_type

    @staticmethod
    def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None:
    def get_enclosing_node_type_and_id(
        node_config: Mapping[str, Any],
    ) -> tuple[NodeType, str] | None:
        in_loop = node_config.get("isInLoop", False)
        in_iteration = node_config.get("isInIteration", False)
        if in_loop:

@@ -306,7 +311,10 @@ class Workflow(Base):
        if "nodes" not in graph_dict:
            return []

        start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None)
        start_node = next(
            (node for node in graph_dict["nodes"] if node["data"]["type"] == "start"),
            None,
        )
        if not start_node:
            return []

@@ -359,7 +367,9 @@ class Workflow(Base):
        return db.session.execute(stmt).scalar_one()

    @property
    def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
    def environment_variables(
        self,
    ) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
        # TODO: find some way to init `self._environment_variables` when instance created.
        if self._environment_variables is None:
            self._environment_variables = "{}"

@@ -376,7 +386,9 @@ class Workflow(Base):
        ]

        # decrypt secret variables value
        def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
        def decrypt_func(
            var: Variable,
        ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
            if isinstance(var, SecretVariable):
                return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
            elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):

@@ -537,7 +549,10 @@ class WorkflowRun(Base):
    version: Mapped[str] = mapped_column(String(255))
    graph: Mapped[str | None] = mapped_column(sa.Text)
    inputs: Mapped[str | None] = mapped_column(sa.Text)
    status: Mapped[str] = mapped_column(String(255))  # running, succeeded, failed, stopped, partial-succeeded
    status: Mapped[str] = mapped_column(
        EnumText(WorkflowExecutionStatus, length=255),
        nullable=False,
    )
    outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}")
    error: Mapped[str | None] = mapped_column(sa.Text)
    elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))

@@ -549,6 +564,15 @@ class WorkflowRun(Base):
    finished_at: Mapped[datetime | None] = mapped_column(DateTime)
    exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)

    pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
        "WorkflowPause",
        primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
        uselist=False,
        # require explicit preloading.
        lazy="raise",
        back_populates="workflow_run",
    )

    @property
    def created_by_account(self):
        created_by_role = CreatorUserRole(self.created_by_role)

@@ -1073,7 +1097,10 @@ class ConversationVariable(Base):
        DateTime, nullable=False, server_default=func.current_timestamp(), index=True
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
        DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str):

@@ -1101,10 +1128,6 @@ class ConversationVariable(Base):
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])


def _naive_utc_datetime():
    return naive_utc_now()


class WorkflowDraftVariable(Base):
    """`WorkflowDraftVariable` records variables and outputs generated while
    debugging a workflow or chatflow.

@@ -1138,14 +1161,14 @@ class WorkflowDraftVariable(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
    )

    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

@@ -1412,8 +1435,8 @@ class WorkflowDraftVariable(Base):
        file_id: str | None = None,
    ) -> "WorkflowDraftVariable":
        variable = WorkflowDraftVariable()
        variable.created_at = _naive_utc_datetime()
        variable.updated_at = _naive_utc_datetime()
        variable.created_at = naive_utc_now()
        variable.updated_at = naive_utc_now()
        variable.description = description
        variable.app_id = app_id
        variable.node_id = node_id

@@ -1518,7 +1541,7 @@ class WorkflowDraftVariableFile(Base):
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        default=naive_utc_now,
        server_default=func.current_timestamp(),
    )

@@ -1583,3 +1606,68 @@ class WorkflowDraftVariableFile(Base):

def is_system_variable_editable(name: str) -> bool:
    return name in _EDITABLE_SYSTEM_VARIABLE


class WorkflowPause(DefaultFieldsMixin, Base):
    """
    WorkflowPause records the paused state and related metadata for a specific workflow run.

    Each `WorkflowRun` can have zero or one associated `WorkflowPause`, depending on its execution status.
    If a `WorkflowRun` is in the `PAUSED` state, there must be a corresponding `WorkflowPause`
    that has not yet been resumed.
    Otherwise, there should be no active (non-resumed) `WorkflowPause` linked to that run.

    This model captures the execution context required to resume workflow processing at a later time.
    """

    __tablename__ = "workflow_pauses"
    __table_args__ = (
        # Design Note:
        # Instead of adding a `pause_id` field to the `WorkflowRun` model (which would require
        # a migration on a potentially large table), we reference `WorkflowRun` from
        # `WorkflowPause` and enforce a unique constraint on `workflow_run_id` to guarantee
        # a one-to-one relationship.
        UniqueConstraint("workflow_run_id"),
    )

    # `workflow_id` represents the unique identifier of the workflow associated with this pause.
    # It corresponds to the `id` field in the `Workflow` model.
    #
    # Since an application can have multiple versions of a workflow, each with its own unique ID,
    # the `app_id` alone is insufficient to determine which workflow version should be loaded
    # when resuming a suspended workflow.
    workflow_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
    )

    # `workflow_run_id` represents the identifier of the execution of the workflow,
    # corresponding to the `id` field of `WorkflowRun`.
    workflow_run_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
    )

    # `resumed_at` records the timestamp when the suspended workflow was resumed.
    # It is set to `NULL` if the workflow has not been resumed.
    #
    # NOTE: Resuming a suspended WorkflowPause does not delete the record immediately.
    # It only sets `resumed_at` to a non-null value.
    resumed_at: Mapped[datetime | None] = mapped_column(
        sa.DateTime,
        nullable=True,
    )

    # `state_object_key` stores the object key referencing the serialized runtime state
    # of the `GraphEngine`. This object captures the complete execution context of the
    # workflow at the moment it was paused, enabling accurate resumption.
    state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)

    # Relationship to WorkflowRun
    workflow_run: Mapped["WorkflowRun"] = orm.relationship(
        foreign_keys=[workflow_run_id],
        # require explicit preloading.
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
        back_populates="pause",
    )
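Because the relationship is declared with `lazy="raise"`, callers must eager-load it explicitly; this is the query pattern the repository implementation below uses (`run_id` and `session` are assumed here):

from sqlalchemy import select
from sqlalchemy.orm import selectinload

# Eager-load the pause; plain lazy attribute access would raise.
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == run_id)
workflow_run = session.scalar(stmt)
pause = workflow_run.pause  # safe: already loaded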
@@ -38,6 +38,7 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom

@@ -251,6 +252,116 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
        """
        ...

    def create_workflow_pause(
        self,
        workflow_run_id: str,
        state_owner_user_id: str,
        state: str,
    ) -> WorkflowPauseEntity:
        """
        Create a new workflow pause state.

        Creates a pause state for a workflow run, storing the current execution
        state and marking the workflow as paused. This is used when a workflow
        needs to be suspended and later resumed.

        Args:
            workflow_run_id: Identifier of the workflow run to pause
            state_owner_user_id: User ID who owns the pause state for file storage
            state: Serialized workflow execution state (JSON string)

        Returns:
            WorkflowPauseEntity representing the created pause state

        Raises:
            ValueError: If workflow_run_id is invalid or workflow run doesn't exist
            RuntimeError: If workflow is already paused or in invalid state
        """
        # NOTE: we may get rid of the `state_owner_user_id` in the parameter list.
        # However, removing it would require an extra query for the `Workflow` model
        # while creating the pause.
        ...

    def resume_workflow_pause(
        self,
        workflow_run_id: str,
        pause_entity: WorkflowPauseEntity,
    ) -> WorkflowPauseEntity:
        """
        Resume a paused workflow.

        Marks a paused workflow as resumed, sets the `resumed_at` field of the
        WorkflowPauseEntity, and returns the workflow to running status. Returns
        the pause entity that was resumed.

        The returned `WorkflowPauseEntity` model has `resumed_at` set.

        NOTE: this method does not delete the corresponding `WorkflowPauseEntity` record and associated states.
        It's the caller's responsibility to clear the corresponding state with `delete_workflow_pause`.

        Args:
            workflow_run_id: Identifier of the workflow run to resume
            pause_entity: The pause entity to resume

        Returns:
            WorkflowPauseEntity representing the resumed pause state

        Raises:
            ValueError: If workflow_run_id is invalid
            RuntimeError: If workflow is not paused or already resumed
        """
        ...

    def delete_workflow_pause(
        self,
        pause_entity: WorkflowPauseEntity,
    ) -> None:
        """
        Delete a workflow pause state.

        Permanently removes the pause state for a workflow run, including
        the stored state file. Used for cleanup operations when a paused
        workflow is no longer needed.

        Args:
            pause_entity: The pause entity to delete

        Raises:
            ValueError: If pause_entity is invalid
            RuntimeError: If workflow is not paused

        Note:
            This operation is irreversible. The stored workflow state will be
            permanently deleted along with the pause record.
        """
        ...

    def prune_pauses(
        self,
        expiration: datetime,
        resumption_expiration: datetime,
        limit: int | None = None,
    ) -> Sequence[str]:
        """
        Clean up expired and old pause states.

        Removes pause states that have expired (created before the expiration time)
        and pause states that were resumed before `resumption_expiration`.
        This is used for maintenance and cleanup operations.

        Args:
            expiration: Remove pause states created before this time
            resumption_expiration: Remove pause states resumed before this time
            limit: Maximum number of records deleted in one call

        Returns:
            A list of IDs for pause records that were pruned

        Raises:
            ValueError: If parameters are invalid
        """
        ...

    def get_daily_runs_statistics(
        self,
        tenant_id: str,

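Taken together, the intended lifecycle against this protocol looks roughly like this (`repo`, `run_id`, `owner_id`, and the serialized state are assumed):

# Pause: the run must currently be RUNNING.
pause = repo.create_workflow_pause(
    workflow_run_id=run_id,
    state_owner_user_id=owner_id,
    state=serialized_state_json,
)

# Resume: sets resumed_at and flips the run back to RUNNING.
resumed = repo.resume_workflow_pause(workflow_run_id=run_id, pause_entity=pause)
assert resumed.resumed_at is not None

# Cleanup is explicit; resuming alone does not delete the record.
repo.delete_workflow_pause(resumed)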
@@ -20,19 +20,26 @@ Implementation Notes:
"""

import logging
import uuid
from collections.abc import Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast

import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session, selectinload, sessionmaker

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.types import (

@@ -45,6 +52,10 @@ from repositories.types import (
logger = logging.getLogger(__name__)


class _WorkflowRunError(Exception):
    pass


class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
    """
    SQLAlchemy implementation of APIWorkflowRunRepository.

@ -301,6 +312,281 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
||||
return total_deleted
|
||||
|
||||
def create_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
state_owner_user_id: str,
|
||||
state: str,
|
||||
) -> WorkflowPauseEntity:
|
||||
"""
|
||||
Create a new workflow pause state.
|
||||
|
||||
Creates a pause state for a workflow run, storing the current execution
|
||||
state and marking the workflow as paused. This is used when a workflow
|
||||
needs to be suspended and later resumed.
|
||||
|
||||
Args:
|
||||
workflow_run_id: Identifier of the workflow run to pause
|
||||
state_owner_user_id: User ID who owns the pause state for file storage
|
||||
state: Serialized workflow execution state (JSON string)
|
||||
|
||||
Returns:
|
||||
RepositoryWorkflowPauseEntity representing the created pause state
|
||||
|
||||
Raises:
|
||||
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
|
||||
RuntimeError: If workflow is already paused or in invalid state
|
||||
"""
|
||||
previous_pause_model_query = select(WorkflowPauseModel).where(
|
||||
WorkflowPauseModel.workflow_run_id == workflow_run_id
|
||||
)
|
||||
with self._session_maker() as session, session.begin():
|
||||
# Get the workflow run
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||
|
||||
# Check if workflow is in RUNNING status
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
raise _WorkflowRunError(
|
||||
f"Only WorkflowRun with RUNNING status can be paused, "
|
||||
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
||||
)
|
||||
#
|
||||
previous_pause = session.scalars(previous_pause_model_query).first()
|
||||
if previous_pause:
|
||||
self._delete_pause_model(session, previous_pause)
|
||||
# we need to flush here to ensure that the old one is actually deleted.
|
||||
session.flush()
|
||||
|
||||
state_obj_key = f"workflow-state-{uuid.uuid4()}.json"
|
||||
storage.save(state_obj_key, state.encode())
|
||||
# Upload the state file
|
||||
|
||||
# Create the pause record
|
||||
pause_model = WorkflowPauseModel()
|
||||
pause_model.id = str(uuidv7())
|
||||
pause_model.workflow_id = workflow_run.workflow_id
|
||||
pause_model.workflow_run_id = workflow_run.id
|
||||
pause_model.state_object_key = state_obj_key
|
||||
pause_model.created_at = naive_utc_now()
|
||||
|
||||
# Update workflow run status
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Save everything in a transaction
|
||||
session.add(pause_model)
|
||||
session.add(workflow_run)
|
||||
|
||||
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||
|
||||
def get_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
) -> WorkflowPauseEntity | None:
|
||||
"""
|
||||
Get an existing workflow pause state.
|
||||
|
||||
Retrieves the pause state for a specific workflow run if it exists.
|
||||
Used to check if a workflow is paused and to retrieve its saved state.
|
||||
|
||||
Args:
|
||||
workflow_run_id: Identifier of the workflow run to get pause state for
|
||||
|
||||
Returns:
|
||||
RepositoryWorkflowPauseEntity if pause state exists, None otherwise
|
||||
|
||||
Raises:
|
||||
ValueError: If workflow_run_id is invalid
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
# Query workflow run with pause and state file
|
||||
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
|
||||
if workflow_run is None:
|
||||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||
|
||||
pause_model = workflow_run.pause
|
||||
if pause_model is None:
|
||||
return None
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||
|
||||
def resume_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
) -> WorkflowPauseEntity:
|
||||
"""
|
||||
Resume a paused workflow.
|
||||
|
||||
Marks a paused workflow as resumed, clearing the pause state and
|
||||
returning the workflow to running status. Returns the pause entity
|
||||
that was resumed.
|
||||
|
||||
Args:
|
||||
workflow_run_id: Identifier of the workflow run to resume
|
||||
pause_entity: The pause entity to resume
|
||||
|
||||
Returns:
|
||||
RepositoryWorkflowPauseEntity representing the resumed pause state
|
||||
|
||||
Raises:
|
||||
ValueError: If workflow_run_id is invalid
|
||||
RuntimeError: If workflow is not paused or already resumed
|
||||
"""
|
||||
with self._session_maker() as session, session.begin():
|
||||
# Get the workflow run with pause
|
||||
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
|
||||
if workflow_run is None:
|
||||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||
|
||||
if workflow_run.status != WorkflowExecutionStatus.PAUSED:
|
||||
raise _WorkflowRunError(
|
||||
f"WorkflowRun is not in PAUSED status, workflow_run_id={workflow_run_id}, "
|
||||
f"current_status={workflow_run.status}"
|
||||
)
|
||||
pause_model = workflow_run.pause
|
||||
if pause_model is None:
|
||||
raise _WorkflowRunError(f"No pause state found for workflow run: {workflow_run_id}")
|
||||
|
||||
if pause_model.id != pause_entity.id:
|
||||
raise _WorkflowRunError(
|
||||
"different id in WorkflowPause and WorkflowPauseEntity, "
|
||||
f"WorkflowPause.id={pause_model.id}, "
|
||||
f"WorkflowPauseEntity.id={pause_entity.id}"
|
||||
)
|
||||
|
||||
if pause_model.resumed_at is not None:
|
||||
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
||||
|
||||
# Mark as resumed
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
workflow_run.pause_id = None # type: ignore
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
|
||||
session.add(pause_model)
|
||||
session.add(workflow_run)
|
||||
|
||||
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||
|
||||
def delete_workflow_pause(
|
||||
self,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a workflow pause state.
|
||||
|
||||
Permanently removes the pause state for a workflow run, including
|
||||
the stored state file. Used for cleanup operations when a paused
|
||||
workflow is no longer needed.
|
||||
|
||||
Args:
|
||||
pause_entity: The pause entity to delete
|
||||
|
||||
Raises:
|
||||
ValueError: If pause_entity is invalid
|
||||
_WorkflowRunError: If workflow is not paused
|
||||
|
||||
Note:
|
||||
This operation is irreversible. The stored workflow state will be
|
||||
permanently deleted along with the pause record.
|
||||
"""
|
||||
with self._session_maker() as session, session.begin():
|
||||
# Get the pause model by ID
|
||||
pause_model = session.get(WorkflowPauseModel, pause_entity.id)
|
||||
if pause_model is None:
|
||||
raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
|
||||
self._delete_pause_model(session, pause_model)
|
||||
|
||||
@staticmethod
|
||||
def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
|
||||
storage.delete(pause_model.state_object_key)
|
||||
|
||||
# Delete the pause record
|
||||
session.delete(pause_model)
|
||||
|
||||
logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id)

    def prune_pauses(
        self,
        expiration: datetime,
        resumption_expiration: datetime,
        limit: int | None = None,
    ) -> Sequence[str]:
        """
        Clean up expired and old pause states.

        Removes pause states that have expired (created before the expiration
        time) and pause states that were resumed before the resumption
        expiration time. This is used for maintenance and cleanup operations.

        Args:
            expiration: Remove pause states created before this time
            resumption_expiration: Remove pause states resumed before this time
            limit: maximum number of records deleted in one call

        Returns:
            a list of ids for pause records that were pruned

        Raises:
            ValueError: If parameters are invalid
        """
        _limit: int = limit or 1000
        pruned_record_ids: list[str] = []
        cond = or_(
            WorkflowPauseModel.created_at < expiration,
            and_(
                WorkflowPauseModel.resumed_at.is_not(null()),
                WorkflowPauseModel.resumed_at < resumption_expiration,
            ),
        )
        # First, collect pause records to delete with their state files
        # Expired pauses (created before expiration time)
        stmt = select(WorkflowPauseModel).where(cond).limit(_limit)

        with self._session_maker(expire_on_commit=False) as session:
            # Old resumed pauses (resumed more than resumption_duration ago)

            # Get all records to delete
            pauses_to_delete = session.scalars(stmt).all()

        # Delete state files from storage
        for pause in pauses_to_delete:
            with self._session_maker(expire_on_commit=False) as session, session.begin():
                # todo: this issues a separate query for each WorkflowPauseModel record.
                # consider batching this lookup.
                try:
                    storage.delete(pause.state_object_key)
                    logger.info(
                        "Deleted state object for pause, pause_id=%s, object_key=%s",
                        pause.id,
                        pause.state_object_key,
                    )
                except Exception:
                    logger.exception(
                        "Failed to delete state file for pause, pause_id=%s, object_key=%s",
                        pause.id,
                        pause.state_object_key,
                    )
                    continue
                session.delete(pause)
                pruned_record_ids.append(pause.id)
                logger.info(
                    "workflow pause records deleted, id=%s, resumed_at=%s",
                    pause.id,
                    pause.resumed_at,
                )

        return pruned_record_ids
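
    # Sketch (not part of the commit): assuming the model maps to a
    # workflow_pauses table, the `cond` predicate above compiles to SQL of
    # roughly this shape, dialect details aside:
    #   SELECT * FROM workflow_pauses
    #   WHERE created_at < :expiration
    #      OR (resumed_at IS NOT NULL AND resumed_at < :resumption_expiration)
    #   LIMIT :limit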

    def get_daily_runs_statistics(
        self,
        tenant_id: str,

@ -510,3 +796,69 @@ GROUP BY
        )

        return cast(list[AverageInteractionStats], response_data)


class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
    """
    Private implementation of WorkflowPauseEntity for SQLAlchemy repository.

    This implementation is internal to the repository layer and provides
    the concrete implementation of the WorkflowPauseEntity interface.
    """

    def __init__(
        self,
        *,
        pause_model: WorkflowPauseModel,
    ) -> None:
        self._pause_model = pause_model
        self._cached_state: bytes | None = None

    @classmethod
    def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
        """
        Create a _PrivateWorkflowPauseEntity from database models.

        Args:
            workflow_pause_model: The WorkflowPause database model

        Returns:
            _PrivateWorkflowPauseEntity: The constructed entity

        Raises:
            ValueError: If required model attributes are missing
        """
        return cls(pause_model=workflow_pause_model)

    @property
    def id(self) -> str:
        return self._pause_model.id

    @property
    def workflow_execution_id(self) -> str:
        return self._pause_model.workflow_run_id

    def get_state(self) -> bytes:
        """
        Retrieve the serialized workflow state from storage.

        Returns:
            bytes: The serialized workflow state

        Raises:
            FileNotFoundError: If the state file cannot be found
            IOError: If there are issues reading the state file
            _WorkflowRunError: If the state cannot be deserialized properly
        """
        if self._cached_state is not None:
            return self._cached_state

        # Load the state from storage
        state_data = storage.load(self._pause_model.state_object_key)
        self._cached_state = state_data
        return state_data

    @property
    def resumed_at(self) -> datetime | None:
        return self._pause_model.resumed_at
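
A minimal usage sketch of the pause lifecycle this repository exposes (an
illustration, not part of the commit; it assumes a repository instance wired
up as in the tests further below):

    from datetime import timedelta

    from libs.datetime_utils import naive_utc_now

    def pause_then_resume(repo, workflow_run_id: str, user_id: str, state: str) -> None:
        # Pause a RUNNING run: persists the serialized state and sets status to PAUSED.
        pause = repo.create_workflow_pause(
            workflow_run_id=workflow_run_id,
            state_owner_user_id=user_id,
            state=state,
        )
        # Resume: stamps resumed_at and flips the run back to RUNNING.
        repo.resume_workflow_pause(workflow_run_id=workflow_run_id, pause_entity=pause)
        # Maintenance: prune pauses created over a day ago or resumed over a week ago.
        repo.prune_pauses(
            expiration=naive_utc_now() - timedelta(days=1),
            resumption_expiration=naive_utc_now() - timedelta(days=7),
        )
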
@ -1,6 +1,7 @@
import threading
from collections.abc import Sequence

from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker

import contexts

@ -14,17 +15,26 @@ from models import (
    WorkflowRun,
    WorkflowRunTriggeredFrom,
)
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory


class WorkflowRunService:
    def __init__(self):
    _session_factory: sessionmaker
    _workflow_run_repo: APIWorkflowRunRepository

    def __init__(self, session_factory: Engine | sessionmaker | None = None):
        """Initialize WorkflowRunService with repository dependencies."""
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        if session_factory is None:
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        elif isinstance(session_factory, Engine):
            session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)

        self._session_factory = session_factory
        self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker
            self._session_factory
        )
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)

    def get_paginate_advanced_chat_workflow_runs(
        self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
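
The reworked constructor now accepts three forms; a short sketch of each (the
variable names here are illustrative, not from the commit):

    # 1. Default: bind to the global db engine.
    service = WorkflowRunService()
    # 2. A raw Engine is wrapped in a sessionmaker internally.
    service = WorkflowRunService(engine)
    # 3. A preconfigured sessionmaker is used as-is.
    service = WorkflowRunService(sessionmaker(bind=engine, expire_on_commit=False))
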
@ -0,0 +1 @@
# Core integration tests package

@ -0,0 +1 @@
# App integration tests package

@ -0,0 +1 @@
# Layers integration tests package

@ -0,0 +1,520 @@
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.

This test suite covers complete integration scenarios including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using test storage backend
- Complete workflow: event -> state serialization -> database save -> storage save
- Testing with actual WorkflowRunService (not mocked)
- Real Workflow and WorkflowRun instances in database
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in database
- Error handling with real database constraints
- Multiple pause events in sequence
- Integration with real ReadOnlyGraphRuntimeState implementations

These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""

import json
import uuid
from time import time

import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session

from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService


class _TestCommandChannelImpl:
    """Real implementation of CommandChannel for testing."""

    def __init__(self):
        self._commands: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """Fetch pending commands for this GraphEngine instance."""
        return self._commands.copy()

    def send_command(self, command: GraphEngineCommand) -> None:
        """Send a command to be processed by this GraphEngine instance."""
        self._commands.append(command)


class TestPauseStatePersistenceLayerTestContainers:
    """Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""

    @pytest.fixture
    def engine(self, db_session_with_containers: Session):
        """Get database engine from TestContainers session."""
        bind = db_session_with_containers.get_bind()
        assert isinstance(bind, Engine)
        return bind

    @pytest.fixture
    def file_service(self, engine: Engine):
        """Create FileService instance with TestContainers engine."""
        return FileService(engine)

    @pytest.fixture
    def workflow_run_service(self, engine: Engine, file_service: FileService):
        """Create WorkflowRunService instance with TestContainers engine and FileService."""
        return WorkflowRunService(engine)

    @pytest.fixture(autouse=True)
    def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
        """Set up test data for each test method using TestContainers."""
        # Create test tenant and account
        from models.account import Tenant, TenantAccountJoin, TenantAccountRole

        tenant = Tenant(
            name="Test Tenant",
            status="normal",
        )
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()

        account = Account(
            email="test@example.com",
            name="Test User",
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()

        # Create tenant-account join
        tenant_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(tenant_join)
        db_session_with_containers.commit()

        # Set test data
        self.test_tenant_id = tenant.id
        self.test_user_id = account.id
        self.test_app_id = str(uuid.uuid4())
        self.test_workflow_id = str(uuid.uuid4())
        self.test_workflow_run_id = str(uuid.uuid4())

        # Create test workflow
        self.test_workflow = Workflow(
            id=self.test_workflow_id,
            tenant_id=self.test_tenant_id,
            app_id=self.test_app_id,
            type="workflow",
            version="draft",
            graph='{"nodes": [], "edges": []}',
            features='{"file_upload": {"enabled": false}}',
            created_by=self.test_user_id,
            created_at=naive_utc_now(),
        )

        # Create test workflow run
        self.test_workflow_run = WorkflowRun(
            id=self.test_workflow_run_id,
            tenant_id=self.test_tenant_id,
            app_id=self.test_app_id,
            workflow_id=self.test_workflow_id,
            type="workflow",
            triggered_from="debugging",
            version="draft",
            status=WorkflowExecutionStatus.RUNNING,
            created_by=self.test_user_id,
            created_by_role="account",
            created_at=naive_utc_now(),
        )

        # Store session and service instances
        self.session = db_session_with_containers
        self.file_service = file_service
        self.workflow_run_service = workflow_run_service

        # Save test data to database
        self.session.add(self.test_workflow)
        self.session.add(self.test_workflow_run)
        self.session.commit()

        yield

        # Cleanup
        self._cleanup_test_data()

    def _cleanup_test_data(self):
        """Clean up test data after each test method."""
        try:
            # Clean up workflow pauses
            self.session.execute(delete(WorkflowPauseModel))
            # Clean up upload files
            self.session.execute(
                delete(UploadFile).where(
                    UploadFile.tenant_id == self.test_tenant_id,
                )
            )
            # Clean up workflow runs
            self.session.execute(
                delete(WorkflowRun).where(
                    WorkflowRun.tenant_id == self.test_tenant_id,
                    WorkflowRun.app_id == self.test_app_id,
                )
            )
            # Clean up workflows
            self.session.execute(
                delete(Workflow).where(
                    Workflow.tenant_id == self.test_tenant_id,
                    Workflow.app_id == self.test_app_id,
                )
            )
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise e

    def _create_graph_runtime_state(
        self,
        outputs: dict[str, object] | None = None,
        total_tokens: int = 0,
        node_run_steps: int = 0,
        variables: dict[tuple[str, str], object] | None = None,
        workflow_run_id: str | None = None,
    ) -> ReadOnlyGraphRuntimeState:
        """Create a real GraphRuntimeState for testing."""
        start_at = time()

        execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())

        # Create variable pool
        variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
        if variables:
            for (node_id, var_key), value in variables.items():
                variable_pool.add([node_id, var_key], value)

        # Create LLM usage
        llm_usage = LLMUsage.empty_usage()

        # Create graph runtime state
        graph_runtime_state = GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=start_at,
            total_tokens=total_tokens,
            llm_usage=llm_usage,
            outputs=outputs or {},
            node_run_steps=node_run_steps,
        )

        return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)

    def _create_pause_state_persistence_layer(
        self,
        workflow_run: WorkflowRun | None = None,
        workflow: Workflow | None = None,
        state_owner_user_id: str | None = None,
    ) -> PauseStatePersistenceLayer:
        """Create PauseStatePersistenceLayer with real dependencies."""
        owner_id = state_owner_user_id
        if owner_id is None:
            if workflow is not None and workflow.created_by:
                owner_id = workflow.created_by
            elif workflow_run is not None and workflow_run.created_by:
                owner_id = workflow_run.created_by
            else:
                owner_id = getattr(self, "test_user_id", None)

        assert owner_id is not None
        owner_id = str(owner_id)

        return PauseStatePersistenceLayer(
            session_factory=self.session.get_bind(),
            state_owner_user_id=owner_id,
        )

    def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
        """Test complete pause flow: event -> state serialization -> database save -> storage save."""
        # Arrange
        layer = self._create_pause_state_persistence_layer()

        # Create real graph runtime state with test data
        test_outputs = {"result": "test_output", "step": "intermediate"}
        test_variables = {
            ("node1", "var1"): "string_value",
            ("node2", "var2"): {"complex": "object"},
        }
        graph_runtime_state = self._create_graph_runtime_state(
            outputs=test_outputs,
            total_tokens=100,
            node_run_steps=5,
            variables=test_variables,
        )

        command_channel = _TestCommandChannelImpl()
        layer.initialize(graph_runtime_state, command_channel)

        # Create pause event
        event = GraphRunPausedEvent(
            reason=SchedulingPause(message="test pause"),
            outputs={"intermediate": "result"},
        )

        # Act
        layer.on_event(event)

        # Assert - Verify pause state was saved to database
        self.session.refresh(self.test_workflow_run)
        workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
        assert workflow_run is not None
        assert workflow_run.status == WorkflowExecutionStatus.PAUSED

        # Verify pause state exists in database
        pause_model = self.session.scalars(
            select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
        ).first()
        assert pause_model is not None
        assert pause_model.workflow_id == self.test_workflow_id
        assert pause_model.workflow_run_id == self.test_workflow_run_id
        assert pause_model.state_object_key != ""
        assert pause_model.resumed_at is None

        storage_content = storage.load(pause_model.state_object_key).decode()
        expected_state = json.loads(graph_runtime_state.dumps())
        actual_state = json.loads(storage_content)

        assert actual_state == expected_state

    def test_state_persistence_and_retrieval(self, db_session_with_containers):
        """Test that pause state can be persisted and retrieved correctly."""
        # Arrange
        layer = self._create_pause_state_persistence_layer()

        # Create complex test data
        complex_outputs = {
            "nested": {"key": "value", "number": 42},
            "list": [1, 2, 3, {"nested": "item"}],
            "boolean": True,
            "null_value": None,
        }
        complex_variables = {
            ("node1", "var1"): "string_value",
            ("node2", "var2"): {"complex": "object"},
            ("node3", "var3"): [1, 2, 3],
        }

        graph_runtime_state = self._create_graph_runtime_state(
            outputs=complex_outputs,
            total_tokens=250,
            node_run_steps=10,
            variables=complex_variables,
        )

        command_channel = _TestCommandChannelImpl()
        layer.initialize(graph_runtime_state, command_channel)

        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))

        # Act - Save pause state
        layer.on_event(event)

        # Assert - Retrieve and verify
        pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
        assert pause_entity is not None
        assert pause_entity.workflow_execution_id == self.test_workflow_run_id

        state_bytes = pause_entity.get_state()
        retrieved_state = json.loads(state_bytes.decode())
        expected_state = json.loads(graph_runtime_state.dumps())

        assert retrieved_state == expected_state
        assert retrieved_state["outputs"] == complex_outputs
        assert retrieved_state["total_tokens"] == 250
        assert retrieved_state["node_run_steps"] == 10

    def test_database_transaction_handling(self, db_session_with_containers):
        """Test that database transactions are handled correctly."""
        # Arrange
        layer = self._create_pause_state_persistence_layer()
        graph_runtime_state = self._create_graph_runtime_state(
            outputs={"test": "transaction"},
            total_tokens=50,
        )

        command_channel = _TestCommandChannelImpl()
        layer.initialize(graph_runtime_state, command_channel)

        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))

        # Act
        layer.on_event(event)

        # Assert - Verify data is committed and accessible in new session
        with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
            workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
            assert workflow_run is not None
            assert workflow_run.status == WorkflowExecutionStatus.PAUSED

            pause_model = new_session.scalars(
                select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
            ).first()
            assert pause_model is not None
            assert pause_model.workflow_run_id == self.test_workflow_run_id
            assert pause_model.resumed_at is None
            assert pause_model.state_object_key != ""

    def test_file_storage_integration(self, db_session_with_containers):
        """Test integration with file storage system."""
        # Arrange
        layer = self._create_pause_state_persistence_layer()

        # Create large state data to test storage
        large_outputs = {"data": "x" * 10000}  # 10KB of data
        graph_runtime_state = self._create_graph_runtime_state(
            outputs=large_outputs,
            total_tokens=1000,
        )

        command_channel = _TestCommandChannelImpl()
        layer.initialize(graph_runtime_state, command_channel)

        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))

        # Act
        layer.on_event(event)

        # Assert - Verify file was uploaded to storage
        self.session.refresh(self.test_workflow_run)
        pause_model = self.session.scalars(
            select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
        ).first()
        assert pause_model is not None
        assert pause_model.state_object_key != ""

        # Verify content in storage
        storage_content = storage.load(pause_model.state_object_key).decode()
        assert storage_content == graph_runtime_state.dumps()

    def test_workflow_with_different_creators(self, db_session_with_containers):
        """Test pause state with workflows created by different users."""
        # Arrange - Create workflow with different creator
        different_user_id = str(uuid.uuid4())
        different_workflow = Workflow(
            id=str(uuid.uuid4()),
            tenant_id=self.test_tenant_id,
            app_id=self.test_app_id,
            type="workflow",
            version="draft",
            graph='{"nodes": [], "edges": []}',
            features='{"file_upload": {"enabled": false}}',
            created_by=different_user_id,
            created_at=naive_utc_now(),
        )

        different_workflow_run = WorkflowRun(
            id=str(uuid.uuid4()),
            tenant_id=self.test_tenant_id,
            app_id=self.test_app_id,
            workflow_id=different_workflow.id,
            type="workflow",
            triggered_from="debugging",
            version="draft",
            status=WorkflowExecutionStatus.RUNNING,
            created_by=self.test_user_id,  # run creator intentionally differs from the workflow's creator
            created_by_role="account",
            created_at=naive_utc_now(),
        )

        self.session.add(different_workflow)
        self.session.add(different_workflow_run)
        self.session.commit()

        layer = self._create_pause_state_persistence_layer(
            workflow_run=different_workflow_run,
            workflow=different_workflow,
        )

        graph_runtime_state = self._create_graph_runtime_state(
            outputs={"creator_test": "different_creator"},
            workflow_run_id=different_workflow_run.id,
        )

        command_channel = _TestCommandChannelImpl()
        layer.initialize(graph_runtime_state, command_channel)

        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))

        # Act
        layer.on_event(event)

        # Assert - Should use workflow creator (not run creator)
        self.session.refresh(different_workflow_run)
        pause_model = self.session.scalars(
            select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
        ).first()
        assert pause_model is not None

        # Verify the state owner is the workflow creator
        pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
        assert pause_entity is not None

    def test_layer_ignores_non_pause_events(self, db_session_with_containers):
        """Test that layer ignores non-pause events."""
        # Arrange
        layer = self._create_pause_state_persistence_layer()
        graph_runtime_state = self._create_graph_runtime_state()

        command_channel = _TestCommandChannelImpl()
        layer.initialize(graph_runtime_state, command_channel)

        # Import other event types
        from core.workflow.graph_events.graph import (
            GraphRunFailedEvent,
            GraphRunStartedEvent,
            GraphRunSucceededEvent,
        )

        # Act - Send non-pause events
        layer.on_event(GraphRunStartedEvent())
        layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
        layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))

        # Assert - No pause state should be created
        self.session.refresh(self.test_workflow_run)
        assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING

        pause_states = (
            self.session.query(WorkflowPauseModel)
            .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
            .all()
        )
        assert len(pause_states) == 0

    def test_layer_requires_initialization(self, db_session_with_containers):
        """Test that layer requires proper initialization before handling events."""
        # Arrange
        layer = self._create_pause_state_persistence_layer()
        # Don't initialize - graph_runtime_state should not be set

        event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))

        # Act & Assert - Should raise AttributeError
        with pytest.raises(AttributeError):
            layer.on_event(event)
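
In miniature, the contract this suite exercises (names taken from the tests
above; the engine and user_id values are whatever the fixtures provide):

    layer = PauseStatePersistenceLayer(session_factory=engine, state_owner_user_id=user_id)
    layer.initialize(read_only_state, command_channel)  # wire in runtime state and command channel
    layer.on_event(GraphRunPausedEvent(reason=SchedulingPause(message="...")))  # persists the pause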

@ -0,0 +1,948 @@
"""Comprehensive integration tests for workflow pause functionality.

This test suite covers complete workflow pause functionality including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using the test storage backend
- Complete workflow: create -> pause -> resume -> delete
- Testing with actual FileService (not mocked)
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in the database
- Error handling with real database constraints
- Concurrent access scenarios
- Multi-tenant isolation
- Prune functionality
- File storage integration

These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""

import json
import uuid
from dataclasses import dataclass
from datetime import timedelta

import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, selectinload, sessionmaker

from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
    DifyAPISQLAlchemyWorkflowRunRepository,
    _WorkflowRunError,
)


@dataclass
class PauseWorkflowSuccessCase:
    """Test case for successful pause workflow operations."""

    name: str
    initial_status: WorkflowExecutionStatus
    description: str = ""


@dataclass
class PauseWorkflowFailureCase:
    """Test case for pause workflow failure scenarios."""

    name: str
    initial_status: WorkflowExecutionStatus
    description: str = ""


@dataclass
class ResumeWorkflowSuccessCase:
    """Test case for successful resume workflow operations."""

    name: str
    initial_status: WorkflowExecutionStatus
    description: str = ""


@dataclass
class ResumeWorkflowFailureCase:
    """Test case for resume workflow failure scenarios."""

    name: str
    initial_status: WorkflowExecutionStatus
    pause_resumed: bool
    set_running_status: bool = False
    description: str = ""


@dataclass
class PrunePausesTestCase:
    """Test case for prune pauses operations."""

    name: str
    pause_age: timedelta
    resume_age: timedelta | None
    expected_pruned_count: int
    description: str = ""


def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
    """Create test cases for pause workflow failure scenarios."""
    return [
        PauseWorkflowFailureCase(
            name="pause_already_paused_workflow",
            initial_status=WorkflowExecutionStatus.PAUSED,
            description="Should fail to pause an already paused workflow",
        ),
        PauseWorkflowFailureCase(
            name="pause_completed_workflow",
            initial_status=WorkflowExecutionStatus.SUCCEEDED,
            description="Should fail to pause a completed workflow",
        ),
        PauseWorkflowFailureCase(
            name="pause_failed_workflow",
            initial_status=WorkflowExecutionStatus.FAILED,
            description="Should fail to pause a failed workflow",
        ),
    ]


def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
    """Create test cases for successful resume workflow operations."""
    return [
        ResumeWorkflowSuccessCase(
            name="resume_paused_workflow",
            initial_status=WorkflowExecutionStatus.PAUSED,
            description="Should successfully resume a paused workflow",
        ),
    ]


def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
    """Create test cases for resume workflow failure scenarios."""
    return [
        ResumeWorkflowFailureCase(
            name="resume_already_resumed_workflow",
            initial_status=WorkflowExecutionStatus.PAUSED,
            pause_resumed=True,
            description="Should fail to resume an already resumed workflow",
        ),
        ResumeWorkflowFailureCase(
            name="resume_running_workflow",
            initial_status=WorkflowExecutionStatus.RUNNING,
            pause_resumed=False,
            set_running_status=True,
            description="Should fail to resume a running workflow",
        ),
    ]


def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
    """Create test cases for prune pauses operations."""
    return [
        PrunePausesTestCase(
            name="prune_old_active_pauses",
            pause_age=timedelta(days=7),
            resume_age=None,
            expected_pruned_count=1,
            description="Should prune old active pauses",
        ),
        PrunePausesTestCase(
            name="prune_old_resumed_pauses",
            pause_age=timedelta(hours=12),  # Created 12 hours ago (recent)
            resume_age=timedelta(days=7),
            expected_pruned_count=1,
            description="Should prune old resumed pauses",
        ),
        PrunePausesTestCase(
            name="keep_recent_active_pauses",
            pause_age=timedelta(hours=1),
            resume_age=None,
            expected_pruned_count=0,
            description="Should keep recent active pauses",
        ),
        PrunePausesTestCase(
            name="keep_recent_resumed_pauses",
            pause_age=timedelta(days=1),
            resume_age=timedelta(hours=1),
            expected_pruned_count=0,
            description="Should keep recent resumed pauses",
        ),
    ]
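

# A worked example of the arithmetic these cases rely on (commentary, not test
# data): with pause_age=timedelta(days=7) and the suite's expiration threshold
# of roughly now - 1 day, created_at = now - 7 days falls before the threshold,
# so the pause is pruned; with pause_age=timedelta(hours=1), created_at lands
# after the threshold and the pause is kept.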


class TestWorkflowPauseIntegration:
    """Comprehensive integration tests for workflow pause functionality."""

    @pytest.fixture(autouse=True)
    def setup_test_data(self, db_session_with_containers):
        """Set up test data for each test method using TestContainers."""
        # Create test tenant and account

        tenant = Tenant(
            name="Test Tenant",
            status="normal",
        )
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()

        account = Account(
            email="test@example.com",
            name="Test User",
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()

        # Create tenant-account join
        tenant_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(tenant_join)
        db_session_with_containers.commit()

        # Set test data
        self.test_tenant_id = tenant.id
        self.test_user_id = account.id
        self.test_app_id = str(uuid.uuid4())
        self.test_workflow_id = str(uuid.uuid4())

        # Create test workflow
        self.test_workflow = Workflow(
            id=self.test_workflow_id,
            tenant_id=self.test_tenant_id,
            app_id=self.test_app_id,
            type="workflow",
            version="draft",
            graph='{"nodes": [], "edges": []}',
            features='{"file_upload": {"enabled": false}}',
            created_by=self.test_user_id,
            created_at=naive_utc_now(),
        )

        # Store session instance
        self.session = db_session_with_containers

        # Save test data to database
        self.session.add(self.test_workflow)
        self.session.commit()

        yield

        # Cleanup
        self._cleanup_test_data()

    def _cleanup_test_data(self):
        """Clean up test data after each test method."""
        # Clean up workflow pauses
        self.session.execute(delete(WorkflowPauseModel))
        # Clean up upload files
        self.session.execute(
            delete(UploadFile).where(
                UploadFile.tenant_id == self.test_tenant_id,
            )
        )
        # Clean up workflow runs
        self.session.execute(
            delete(WorkflowRun).where(
                WorkflowRun.tenant_id == self.test_tenant_id,
                WorkflowRun.app_id == self.test_app_id,
            )
        )
        # Clean up workflows
        self.session.execute(
            delete(Workflow).where(
                Workflow.tenant_id == self.test_tenant_id,
                Workflow.app_id == self.test_app_id,
            )
        )
        self.session.commit()

    def _create_test_workflow_run(
        self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
    ) -> WorkflowRun:
        """Create a test workflow run with specified status."""
        workflow_run = WorkflowRun(
            id=str(uuid.uuid4()),
            tenant_id=self.test_tenant_id,
            app_id=self.test_app_id,
            workflow_id=self.test_workflow_id,
            type="workflow",
            triggered_from="debugging",
            version="draft",
            status=status,
            created_by=self.test_user_id,
            created_by_role="account",
            created_at=naive_utc_now(),
        )
        self.session.add(workflow_run)
        self.session.commit()
        return workflow_run

    def _create_test_state(self) -> str:
        """Create a test state string."""
        return json.dumps(
            {
                "node_id": "test-node",
                "node_type": "llm",
                "status": "paused",
                "data": {"key": "value"},
                "timestamp": naive_utc_now().isoformat(),
            }
        )

    def _get_workflow_run_repository(self):
        """Get workflow run repository instance for testing."""
        # Create session factory from the test session
        engine = self.session.get_bind()
        session_factory = sessionmaker(bind=engine, expire_on_commit=False)

        # Create a test-specific repository that implements the missing save method
        class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
            """Test-specific repository that implements the missing save method."""

            def save(self, execution: WorkflowExecution):
                """Implement the missing save method for testing."""
                # For testing purposes, we don't need to implement this method
                # as it's not used in the pause functionality tests
                pass

        # Create and return repository instance
        repository = TestWorkflowRunRepository(session_maker=session_factory)
        return repository

    # ==================== Complete Pause Workflow Tests ====================

    def test_complete_pause_resume_workflow(self):
        """Test complete workflow: create -> pause -> resume -> delete."""
        # Arrange
        workflow_run = self._create_test_workflow_run()
        test_state = self._create_test_state()
        repository = self._get_workflow_run_repository()

        # Act - Create pause state
        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
        )

        # Assert - Pause state created
        assert pause_entity is not None
        assert pause_entity.id is not None
        assert pause_entity.workflow_execution_id == workflow_run.id
        # Convert both to strings for comparison
        retrieved_state = pause_entity.get_state()
        if isinstance(retrieved_state, bytes):
            retrieved_state = retrieved_state.decode()
        assert retrieved_state == test_state

        # Verify database state
        query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
        pause_model = self.session.scalars(query).first()
        assert pause_model is not None
        assert pause_model.resumed_at is None
        assert pause_model.id == pause_entity.id

        self.session.refresh(workflow_run)
        assert workflow_run.status == WorkflowExecutionStatus.PAUSED

        # Act - Get pause state
        retrieved_entity = repository.get_workflow_pause(workflow_run.id)

        # Assert - Pause state retrieved
        assert retrieved_entity is not None
        assert retrieved_entity.id == pause_entity.id
        retrieved_state = retrieved_entity.get_state()
        if isinstance(retrieved_state, bytes):
            retrieved_state = retrieved_state.decode()
        assert retrieved_state == test_state

        # Act - Resume workflow
        resumed_entity = repository.resume_workflow_pause(
            workflow_run_id=workflow_run.id,
            pause_entity=pause_entity,
        )

        # Assert - Workflow resumed
        assert resumed_entity is not None
        assert resumed_entity.id == pause_entity.id
        assert resumed_entity.resumed_at is not None

        # Verify database state
        self.session.refresh(workflow_run)
        assert workflow_run.status == WorkflowExecutionStatus.RUNNING
        self.session.refresh(pause_model)
        assert pause_model.resumed_at is not None

        # Act - Delete pause state
        repository.delete_workflow_pause(pause_entity)

        # Assert - Pause state deleted
        with Session(bind=self.session.get_bind()) as session:
            deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
            assert deleted_pause is None

    def test_pause_workflow_success(self):
        """Test successful pause workflow scenarios."""
        workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        test_state = self._create_test_state()
        repository = self._get_workflow_run_repository()

        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
        )

        assert pause_entity is not None
        assert pause_entity.workflow_execution_id == workflow_run.id

        retrieved_state = pause_entity.get_state()
        if isinstance(retrieved_state, bytes):
            retrieved_state = retrieved_state.decode()
        assert retrieved_state == test_state

        self.session.refresh(workflow_run)
        assert workflow_run.status == WorkflowExecutionStatus.PAUSED
        pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
        pause_model = self.session.scalars(pause_query).first()
        assert pause_model is not None
        assert pause_model.id == pause_entity.id
        assert pause_model.resumed_at is None

    @pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
    def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
        """Test pause workflow failure scenarios."""
        workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
        test_state = self._create_test_state()
        repository = self._get_workflow_run_repository()

        with pytest.raises(_WorkflowRunError):
            repository.create_workflow_pause(
                workflow_run_id=workflow_run.id,
                state_owner_user_id=self.test_user_id,
                state=test_state,
            )

    @pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
    def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
        """Test successful resume workflow scenarios."""
        workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
        test_state = self._create_test_state()
        repository = self._get_workflow_run_repository()

        if workflow_run.status != WorkflowExecutionStatus.RUNNING:
            workflow_run.status = WorkflowExecutionStatus.RUNNING
            self.session.commit()

        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
        )

        self.session.refresh(workflow_run)
        assert workflow_run.status == WorkflowExecutionStatus.PAUSED

        resumed_entity = repository.resume_workflow_pause(
            workflow_run_id=workflow_run.id,
            pause_entity=pause_entity,
        )
        assert resumed_entity is not None
        assert resumed_entity.id == pause_entity.id
        assert resumed_entity.resumed_at is not None

        self.session.refresh(workflow_run)
        assert workflow_run.status == WorkflowExecutionStatus.RUNNING
        pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
        pause_model = self.session.scalars(pause_query).first()
        assert pause_model is not None
        assert pause_model.id == pause_entity.id
        assert pause_model.resumed_at is not None

    def test_resume_running_workflow(self):
        """Test that resuming a workflow that is back in RUNNING status fails."""
        workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        test_state = self._create_test_state()
        repository = self._get_workflow_run_repository()

        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
        )

        self.session.refresh(workflow_run)
        workflow_run.status = WorkflowExecutionStatus.RUNNING
        self.session.add(workflow_run)
        self.session.commit()

        with pytest.raises(_WorkflowRunError):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run.id,
                pause_entity=pause_entity,
            )

    def test_resume_resumed_pause(self):
        """Test that resuming an already resumed pause fails."""
        workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        test_state = self._create_test_state()
        repository = self._get_workflow_run_repository()

        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
        )
        pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
        pause_model.resumed_at = naive_utc_now()
        self.session.add(pause_model)
        self.session.commit()

        with pytest.raises(_WorkflowRunError):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run.id,
                pause_entity=pause_entity,
            )
|
||||
|
||||
# ==================== Error Scenario Tests ====================
|
||||
|
||||
def test_pause_nonexistent_workflow_run(self):
|
||||
"""Test pausing a non-existent workflow run."""
|
||||
# Arrange
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
def test_resume_nonexistent_workflow_run(self):
|
||||
"""Test resuming a non-existent workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Prune Functionality Tests ====================
|
||||
|
||||
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
|
||||
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
|
||||
"""Test various prune pauses scenarios."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create pause state
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Manually adjust timestamps for testing
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - test_case.pause_age
|
||||
|
||||
if test_case.resume_age is not None:
|
||||
# Resume pause and adjust resume time
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
# Need to refresh to get the updated model
|
||||
self.session.refresh(pause_model)
|
||||
# Manually set the resumed_at to an older time for testing
|
||||
pause_model.resumed_at = now - test_case.resume_age
|
||||
self.session.commit() # Commit the resumed_at change
|
||||
# Refresh again to ensure the change is persisted
|
||||
self.session.refresh(pause_model)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune pauses
|
||||
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
|
||||
resumption_time = now - timedelta(
|
||||
days=7, seconds=1
|
||||
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
|
||||
|
||||
# Debug: Check pause state before pruning
|
||||
self.session.refresh(pause_model)
|
||||
print(f"Pause created_at: {pause_model.created_at}")
|
||||
print(f"Pause resumed_at: {pause_model.resumed_at}")
|
||||
print(f"Expiration time: {expiration_time}")
|
||||
print(f"Resumption time: {resumption_time}")
|
||||
|
||||
# Force commit to ensure timestamps are saved
|
||||
self.session.commit()
|
||||
|
||||
# Determine if the pause should be pruned based on timestamps
|
||||
should_be_pruned = False
|
||||
if test_case.resume_age is not None:
|
||||
# If resumed, check if resumed_at is older than resumption_time
|
||||
should_be_pruned = pause_model.resumed_at < resumption_time
|
||||
else:
|
||||
# If not resumed, check if created_at is older than expiration_time
|
||||
should_be_pruned = pause_model.created_at < expiration_time
|
||||
|
||||
# Act - Prune pauses
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
)
|
||||
|
||||
# Assert - Check pruning results
|
||||
if should_be_pruned:
|
||||
assert len(pruned_ids) == test_case.expected_pruned_count
|
||||
# Verify pause was actually deleted
|
||||
# The pause should be in the pruned_ids list if it was pruned
|
||||
assert pause_entity.id in pruned_ids
|
||||
else:
|
||||
assert len(pruned_ids) == 0
|
||||
|
||||
def test_prune_pauses_with_limit(self):
|
||||
"""Test prune pauses with limit parameter."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create multiple pause states
|
||||
pause_entities = []
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
for i in range(5):
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_entities.append(pause_entity)
|
||||
|
||||
# Make all pauses old enough to be pruned
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - timedelta(days=7)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune with limit
|
||||
expiration_time = now - timedelta(days=1)
|
||||
resumption_time = now - timedelta(days=7)
|
||||
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
limit=3,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(pruned_ids) == 3
|
||||
|
||||
# Verify only 3 were deleted
|
||||
remaining_count = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
|
||||
.count()
|
||||
)
|
||||
assert remaining_count == 2
|
||||
|
||||
# ==================== Multi-tenant Isolation Tests ====================
|
||||
|
||||
def test_multi_tenant_pause_isolation(self):
|
||||
"""Test that pause states are properly isolated by tenant."""
|
||||
# Arrange - Create second tenant
|
||||
|
||||
tenant2 = Tenant(
|
||||
name="Test Tenant 2",
|
||||
status="normal",
|
||||
)
|
||||
self.session.add(tenant2)
|
||||
self.session.commit()
|
||||
|
||||
account2 = Account(
|
||||
email="test2@example.com",
|
||||
name="Test User 2",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
self.session.add(account2)
|
||||
self.session.commit()
|
||||
|
||||
tenant2_join = TenantAccountJoin(
|
||||
tenant_id=tenant2.id,
|
||||
account_id=account2.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
self.session.add(tenant2_join)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow for tenant 2
|
||||
workflow2 = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=account2.id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow2)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow runs for both tenants
|
||||
workflow_run1 = self._create_test_workflow_run()
|
||||
workflow_run2 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=workflow2.app_id,
|
||||
workflow_id=workflow2.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=account2.id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run2)
|
||||
self.session.commit()
|
||||
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause for tenant 1
|
||||
pause_entity1 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run1.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Try to access pause from tenant 2 using tenant 1's repository
|
||||
# This should work because we're using the same repository
|
||||
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
|
||||
assert pause_entity2 is None # No pause for tenant 2 yet
|
||||
|
||||
# Create pause for tenant 2
|
||||
pause_entity2 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run2.id,
|
||||
state_owner_user_id=account2.id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Both pauses should exist and be separate
|
||||
assert pause_entity1 is not None
|
||||
assert pause_entity2 is not None
|
||||
assert pause_entity1.id != pause_entity2.id
|
||||
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
|
||||
|
||||
def test_cross_tenant_access_restriction(self):
|
||||
"""Test that cross-tenant access is properly restricted."""
|
||||
# This test would require tenant-specific repositories
|
||||
# For now, we test that pause entities are properly scoped by tenant_id
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Verify pause is properly scoped
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
|
||||
# ==================== File Storage Integration Tests ====================
|
||||
|
||||
def test_file_storage_integration(self):
|
||||
"""Test that state files are properly stored and retrieved."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify file content in storage
|
||||
|
||||
file_key = pause_model.state_object_key
|
||||
storage_content = storage.load(file_key).decode()
|
||||
assert storage_content == test_state
|
||||
|
||||
# Verify retrieval through entity
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
def test_file_cleanup_on_pause_deletion(self):
|
||||
"""Test that files are properly handled on pause deletion."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Get file info before deletion
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
file_key = pause_model.state_object_key
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause record should be deleted
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
        try:
            storage.load(file_key)  # should raise once the object is gone
            pytest.fail("File should be deleted from storage after pause deletion")
        except FileNotFoundError:
            # This is expected - the file should be deleted from storage
            pass
        except Exception as e:
            pytest.fail(f"Unexpected error when checking file deletion: {e}")
    def test_large_state_file_handling(self):
        """Test handling of large state files."""
        # Arrange - Create a large state (1MB)
        large_state = "x" * (1024 * 1024)  # 1MB of data
        large_state_json = json.dumps({"large_data": large_state})

        workflow_run = self._create_test_workflow_run()
        repository = self._get_workflow_run_repository()

        # Act
        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=large_state_json,
        )

        # Assert
        assert pause_entity is not None
        retrieved_state = pause_entity.get_state()
        if isinstance(retrieved_state, bytes):
            retrieved_state = retrieved_state.decode()
        assert retrieved_state == large_state_json

        # Verify the stored object via the key recorded in the database
        pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
        assert pause_model.state_object_key != ""
        loaded_state = storage.load(pause_model.state_object_key)
        assert loaded_state.decode() == large_state_json

    def test_multiple_pause_resume_cycles(self):
        """Test multiple pause/resume cycles on the same workflow run."""
        # Arrange
        workflow_run = self._create_test_workflow_run()
        repository = self._get_workflow_run_repository()

        # Act & Assert - Multiple cycles
        for i in range(3):
            state = json.dumps({"cycle": i, "data": f"state_{i}"})

            # Reset workflow run status to RUNNING before each pause (after first cycle)
            if i > 0:
                self.session.refresh(workflow_run)  # Refresh to get latest state from session
                workflow_run.status = WorkflowExecutionStatus.RUNNING
                self.session.commit()
                self.session.refresh(workflow_run)  # Refresh again after commit

            # Pause
            pause_entity = repository.create_workflow_pause(
                workflow_run_id=workflow_run.id,
                state_owner_user_id=self.test_user_id,
                state=state,
            )
            assert pause_entity is not None

            # Verify pause
            self.session.expire_all()  # Clear session to ensure fresh query
            self.session.refresh(workflow_run)

            # Use the test session directly to verify the pause
            stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
            workflow_run_with_pause = self.session.scalar(stmt)
            pause_model = workflow_run_with_pause.pause

            assert pause_model is not None
            assert pause_model.id == pause_entity.id
            assert pause_model.state_object_key != ""

            # Load file content using storage directly
            file_content = storage.load(pause_model.state_object_key)
            if isinstance(file_content, bytes):
                file_content = file_content.decode()
            assert file_content == state

            # Resume
            resumed_entity = repository.resume_workflow_pause(
                workflow_run_id=workflow_run.id,
                pause_entity=pause_entity,
            )
            assert resumed_entity is not None
            assert resumed_entity.resumed_at is not None

            # Verify resume - check that pause is marked as resumed
            self.session.expire_all()  # Clear session to ensure fresh query
            stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
            resumed_pause_model = self.session.scalar(stmt)
            assert resumed_pause_model is not None
            assert resumed_pause_model.resumed_at is not None

            # Verify workflow run status
            self.session.refresh(workflow_run)
            assert workflow_run.status == WorkflowExecutionStatus.RUNNING
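
The integration tests above pin down the repository-level pause lifecycle end to end: create persists the serialized state to object storage and flips the run to PAUSED, resume stamps resumed_at and returns the run to RUNNING, and delete removes both the row and the stored state object. A minimal sketch of that call sequence, assuming a repository built the same way these tests build theirs; `run`, `user_id`, and `engine_state` are illustrative names:

    # Hedged sketch; `repository` follows the interface exercised above.
    pause = repository.create_workflow_pause(
        workflow_run_id=run.id,           # run must be in RUNNING status
        state_owner_user_id=user_id,      # account that owns the serialized state
        state=json.dumps(engine_state),   # snapshot persisted via `storage`
    )
    resumed = repository.resume_workflow_pause(workflow_run_id=run.id, pause_entity=pause)
    assert resumed.resumed_at is not None     # pause is now marked resumed
    repository.delete_workflow_pause(pause)   # also deletes the state object from storage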

@ -0,0 +1,278 @@

import json
from time import time
from unittest.mock import Mock

import pytest

from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import (
    GraphRunFailedEvent,
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from repositories.factory import DifyAPIRepositoryFactory


class TestDataFactory:
    """Factory helpers for constructing graph events used in tests."""

    @staticmethod
    def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
        return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})

    @staticmethod
    def create_graph_run_started_event() -> GraphRunStartedEvent:
        return GraphRunStartedEvent()

    @staticmethod
    def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
        return GraphRunSucceededEvent(outputs=outputs or {})

    @staticmethod
    def create_graph_run_failed_event(
        error: str = "Test error",
        exceptions_count: int = 1,
    ) -> GraphRunFailedEvent:
        return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)


class MockSystemVariableReadOnlyView:
    """Minimal read-only system variable view for testing."""

    def __init__(self, workflow_execution_id: str | None = None) -> None:
        self._workflow_execution_id = workflow_execution_id

    @property
    def workflow_execution_id(self) -> str | None:
        return self._workflow_execution_id


class MockReadOnlyVariablePool:
    """Mock implementation of ReadOnlyVariablePool for testing."""

    def __init__(self, variables: dict[tuple[str, str], object] | None = None):
        self._variables = variables or {}

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        value = self._variables.get((node_id, variable_key))
        if value is None:
            return None
        mock_segment = Mock(spec=Segment)
        mock_segment.value = value
        return mock_segment

    def get_all_by_node(self, node_id: str) -> dict[str, object]:
        return {key: value for (nid, key), value in self._variables.items() if nid == node_id}

    def get_by_prefix(self, prefix: str) -> dict[str, object]:
        return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}


class MockReadOnlyGraphRuntimeState:
    """Mock implementation of ReadOnlyGraphRuntimeState for testing."""

    def __init__(
        self,
        start_at: float | None = None,
        total_tokens: int = 0,
        node_run_steps: int = 0,
        ready_queue_size: int = 0,
        exceptions_count: int = 0,
        outputs: dict[str, object] | None = None,
        variables: dict[tuple[str, str], object] | None = None,
        workflow_execution_id: str | None = None,
    ):
        self._start_at = start_at or time()
        self._total_tokens = total_tokens
        self._node_run_steps = node_run_steps
        self._ready_queue_size = ready_queue_size
        self._exceptions_count = exceptions_count
        self._outputs = outputs or {}
        self._variable_pool = MockReadOnlyVariablePool(variables)
        self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)

    @property
    def system_variable(self) -> MockSystemVariableReadOnlyView:
        return self._system_variable

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        return self._variable_pool

    @property
    def start_at(self) -> float:
        return self._start_at

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @property
    def ready_queue_size(self) -> int:
        return self._ready_queue_size

    @property
    def exceptions_count(self) -> int:
        return self._exceptions_count

    @property
    def outputs(self) -> dict[str, object]:
        return self._outputs.copy()

    @property
    def llm_usage(self):
        mock_usage = Mock()
        mock_usage.prompt_tokens = 10
        mock_usage.completion_tokens = 20
        mock_usage.total_tokens = 30
        return mock_usage

    def get_output(self, key: str, default: object = None) -> object:
        return self._outputs.get(key, default)

    def dumps(self) -> str:
        return json.dumps(
            {
                "start_at": self._start_at,
                "total_tokens": self._total_tokens,
                "node_run_steps": self._node_run_steps,
                "ready_queue_size": self._ready_queue_size,
                "exceptions_count": self._exceptions_count,
                "outputs": self._outputs,
                "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
                "workflow_execution_id": self._system_variable.workflow_execution_id,
            }
        )


class MockCommandChannel:
    """Mock implementation of CommandChannel for testing."""

    def __init__(self):
        self._commands: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        return self._commands.copy()

    def send_command(self, command: GraphEngineCommand) -> None:
        self._commands.append(command)


class TestPauseStatePersistenceLayer:
    """Unit tests for PauseStatePersistenceLayer."""

    def test_init_with_dependency_injection(self):
        session_factory = Mock(name="session_factory")
        state_owner_user_id = "user-123"

        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id=state_owner_user_id,
        )

        assert layer._session_maker is session_factory
        assert layer._state_owner_user_id == state_owner_user_id
        assert not hasattr(layer, "graph_runtime_state")
        assert not hasattr(layer, "command_channel")

    def test_initialize_sets_dependencies(self):
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")

        graph_runtime_state = MockReadOnlyGraphRuntimeState()
        command_channel = MockCommandChannel()

        layer.initialize(graph_runtime_state, command_channel)

        assert layer.graph_runtime_state is graph_runtime_state
        assert layer.command_channel is command_channel

    def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")

        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)

        graph_runtime_state = MockReadOnlyGraphRuntimeState(
            outputs={"result": "test_output"},
            total_tokens=100,
            workflow_execution_id="run-123",
        )
        command_channel = MockCommandChannel()
        layer.initialize(graph_runtime_state, command_channel)

        event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
        expected_state = graph_runtime_state.dumps()

        layer.on_event(event)

        mock_factory.assert_called_once_with(session_factory)
        mock_repo.create_workflow_pause.assert_called_once_with(
            workflow_run_id="run-123",
            state_owner_user_id="owner-123",
            state=expected_state,
        )

    def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")

        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)

        graph_runtime_state = MockReadOnlyGraphRuntimeState()
        command_channel = MockCommandChannel()
        layer.initialize(graph_runtime_state, command_channel)

        events = [
            TestDataFactory.create_graph_run_started_event(),
            TestDataFactory.create_graph_run_succeeded_event(),
            TestDataFactory.create_graph_run_failed_event(),
        ]

        for event in events:
            layer.on_event(event)

        mock_factory.assert_not_called()
        mock_repo.create_workflow_pause.assert_not_called()

    def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")

        event = TestDataFactory.create_graph_run_paused_event()

        with pytest.raises(AttributeError):
            layer.on_event(event)

    def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")

        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)

        graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
        command_channel = MockCommandChannel()
        layer.initialize(graph_runtime_state, command_channel)

        event = TestDataFactory.create_graph_run_paused_event()

        with pytest.raises(AssertionError):
            layer.on_event(event)

        mock_factory.assert_not_called()
        mock_repo.create_workflow_pause.assert_not_called()
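
Taken together, these unit tests fix the layer's contract: the engine supplies a read-only runtime state and a command channel through initialize(), and of all graph events only GraphRunPausedEvent triggers a persistence call, using the runtime state's dumps() output and the workflow_execution_id from the system variables. A hedged wiring sketch, with the engine's role shown as plain calls (only the interactions asserted above):

    # Illustrative only; mirrors the call order the tests assert.
    layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
    layer.initialize(graph_runtime_state, command_channel)   # done by the engine in production
    layer.on_event(GraphRunPausedEvent(reason=SchedulingPause(message="pause"), outputs={}))
    # -> create_workflow_pause(workflow_run_id=<system workflow_execution_id>,
    #                          state_owner_user_id="owner-123",
    #                          state=graph_runtime_state.dumps())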

@ -0,0 +1,171 @@

"""Tests for _PrivateWorkflowPauseEntity implementation."""

from datetime import datetime
from unittest.mock import MagicMock, patch

from models.workflow import WorkflowPause as WorkflowPauseModel
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity


class TestPrivateWorkflowPauseEntity:
    """Test _PrivateWorkflowPauseEntity implementation."""

    def test_entity_initialization(self):
        """Test entity initialization with required parameters."""
        # Create mock models
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        mock_pause_model.workflow_run_id = "execution-456"
        mock_pause_model.resumed_at = None

        # Create entity
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # Verify initialization
        assert entity._pause_model is mock_pause_model
        assert entity._cached_state is None

    def test_from_models_classmethod(self):
        """Test from_models class method."""
        # Create mock models
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        mock_pause_model.workflow_run_id = "execution-456"

        # Create entity using from_models
        entity = _PrivateWorkflowPauseEntity.from_models(
            workflow_pause_model=mock_pause_model,
        )

        # Verify entity creation
        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model is mock_pause_model

    def test_id_property(self):
        """Test id property returns pause model ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.id == "pause-123"

    def test_workflow_execution_id_property(self):
        """Test workflow_execution_id property returns workflow run ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.workflow_run_id = "execution-456"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.workflow_execution_id == "execution-456"

    def test_resumed_at_property(self):
        """Test resumed_at property returns pause model resumed_at."""
        resumed_at = datetime(2023, 12, 25, 15, 30, 45)

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = resumed_at

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.resumed_at == resumed_at

    def test_resumed_at_property_none(self):
        """Test resumed_at property returns None when not set."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = None

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.resumed_at is None

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_first_call(self, mock_storage):
        """Test get_state loads from storage on first call."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # First call should load from storage
        result = entity.get_state()

        assert result == state_data
        mock_storage.load.assert_called_once_with("test-state-key")
        assert entity._cached_state == state_data

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_cached_call(self, mock_storage):
        """Test get_state returns cached data on subsequent calls."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # First call
        result1 = entity.get_state()
        # Second call should use cache
        result2 = entity.get_state()

        assert result1 == state_data
        assert result2 == state_data
        # Storage should only be called once
        mock_storage.load.assert_called_once_with("test-state-key")

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_with_pre_cached_data(self, mock_storage):
        """Test get_state returns pre-cached data."""
        state_data = b'{"test": "data", "step": 5}'

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # Pre-cache data
        entity._cached_state = state_data

        # Should return cached data without calling storage
        result = entity.get_state()

        assert result == state_data
        mock_storage.load.assert_not_called()

    def test_entity_with_binary_state_data(self):
        """Test entity with binary state data."""
        # Test with binary data that's not valid JSON
        binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = binary_data

            mock_pause_model = MagicMock(spec=WorkflowPauseModel)

            entity = _PrivateWorkflowPauseEntity(
                pause_model=mock_pause_model,
            )

            result = entity.get_state()

            assert result == binary_data
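
The get_state() tests above all assert one caching rule: storage is hit at most once per entity, and a pre-populated cache suppresses the hit entirely. Reduced to a standalone sketch, where the storage object is any blob store exposing a load(key) -> bytes method and the names are illustrative:

    # Load-once caching, as the tests verify.
    class CachedPauseState:
        def __init__(self, state_object_key: str, storage) -> None:
            self._key = state_object_key
            self._storage = storage
            self._cached_state: bytes | None = None

        def get_state(self) -> bytes:
            if self._cached_state is None:          # first call: fetch from storage
                self._cached_state = self._storage.load(self._key)
            return self._cached_state               # later calls: serve from cache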

@ -3,6 +3,7 @@

import time
from unittest.mock import MagicMock

from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel

@ -149,8 +150,8 @@ def test_pause_command():

    assert any(isinstance(e, GraphRunStartedEvent) for e in events)
    pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
    assert len(pause_events) == 1
    assert pause_events[0].reason == "User requested pause"
    assert pause_events[0].reason == SchedulingPause(message="User requested pause")

    graph_execution = engine.graph_runtime_state.graph_execution
    assert graph_execution.is_paused
    assert graph_execution.pause_reason == "User requested pause"
    assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")

@ -0,0 +1,32 @@

"""Tests for workflow pause related enums and constants."""

from core.workflow.enums import (
    WorkflowExecutionStatus,
)


class TestWorkflowExecutionStatus:
    """Test WorkflowExecutionStatus enum."""

    def test_is_ended_method(self):
        """Test is_ended method for different statuses."""
        # Test ended statuses
        ended_statuses = [
            WorkflowExecutionStatus.SUCCEEDED,
            WorkflowExecutionStatus.FAILED,
            WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
            WorkflowExecutionStatus.STOPPED,
        ]

        for status in ended_statuses:
            assert status.is_ended(), f"{status} should be considered ended"

        # Test non-ended statuses
        non_ended_statuses = [
            WorkflowExecutionStatus.SCHEDULED,
            WorkflowExecutionStatus.RUNNING,
            WorkflowExecutionStatus.PAUSED,
        ]

        for status in non_ended_statuses:
            assert not status.is_ended(), f"{status} should not be considered ended"
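
In practice, is_ended() is the predicate that separates terminal runs from ones that can still make progress, and PAUSED deliberately falls on the live side. A small, runnable usage sketch; the can_resume helper is hypothetical and not part of the codebase:

    from core.workflow.enums import WorkflowExecutionStatus

    def can_resume(status: WorkflowExecutionStatus) -> bool:
        # Only a paused run is a resume candidate; terminal runs never are.
        return status == WorkflowExecutionStatus.PAUSED and not status.is_ended()

    assert can_resume(WorkflowExecutionStatus.PAUSED)
    assert not can_resume(WorkflowExecutionStatus.SUCCEEDED)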

@ -0,0 +1,202 @@

from typing import cast

import pytest

from core.file.models import File, FileTransferMethod, FileType
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView


class TestSystemVariableReadOnlyView:
    """Test cases for SystemVariableReadOnlyView class."""

    def test_read_only_property_access(self):
        """Test that all properties return correct values from wrapped instance."""
        # Create test data
        test_file = File(
            id="file-123",
            tenant_id="tenant-123",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="related-123",
        )

        datasource_info = {"key": "value", "nested": {"data": 42}}

        # Create SystemVariable with all fields
        system_var = SystemVariable(
            user_id="user-123",
            app_id="app-123",
            workflow_id="workflow-123",
            files=[test_file],
            workflow_execution_id="exec-123",
            query="test query",
            conversation_id="conv-123",
            dialogue_count=5,
            document_id="doc-123",
            original_document_id="orig-doc-123",
            dataset_id="dataset-123",
            batch="batch-123",
            datasource_type="type-123",
            datasource_info=datasource_info,
            invoke_from="invoke-123",
        )

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test all properties
        assert read_only_view.user_id == "user-123"
        assert read_only_view.app_id == "app-123"
        assert read_only_view.workflow_id == "workflow-123"
        assert read_only_view.workflow_execution_id == "exec-123"
        assert read_only_view.query == "test query"
        assert read_only_view.conversation_id == "conv-123"
        assert read_only_view.dialogue_count == 5
        assert read_only_view.document_id == "doc-123"
        assert read_only_view.original_document_id == "orig-doc-123"
        assert read_only_view.dataset_id == "dataset-123"
        assert read_only_view.batch == "batch-123"
        assert read_only_view.datasource_type == "type-123"
        assert read_only_view.invoke_from == "invoke-123"

    def test_defensive_copying_of_mutable_objects(self):
        """Test that mutable objects are defensively copied."""
        # Create test data
        test_file = File(
            id="file-123",
            tenant_id="tenant-123",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="related-123",
        )

        datasource_info = {"key": "original_value"}

        # Create SystemVariable
        system_var = SystemVariable(
            files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
        )

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test files defensive copying
        files_copy = read_only_view.files
        assert isinstance(files_copy, tuple)  # Should be immutable tuple
        assert len(files_copy) == 1
        assert files_copy[0].id == "file-123"

        # Verify it's a copy (can't modify original through view)
        assert isinstance(files_copy, tuple)
        # tuples don't have append method, so they're immutable

        # Test datasource_info defensive copying
        datasource_copy = read_only_view.datasource_info
        assert datasource_copy is not None
        assert datasource_copy["key"] == "original_value"

        datasource_copy = cast(dict, datasource_copy)
        with pytest.raises(TypeError):
            datasource_copy["key"] = "modified value"

        # Verify original is unchanged
        assert system_var.datasource_info is not None
        assert system_var.datasource_info["key"] == "original_value"
        assert read_only_view.datasource_info is not None
        assert read_only_view.datasource_info["key"] == "original_value"

    def test_always_accesses_latest_data(self):
        """Test that properties always return the latest data from wrapped instance."""
        # Create SystemVariable
        system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Verify initial value
        assert read_only_view.user_id == "original-user"

        # Modify the wrapped instance
        system_var.user_id = "modified-user"

        # Verify view returns the new value
        assert read_only_view.user_id == "modified-user"

    def test_repr_method(self):
        """Test the __repr__ method."""
        # Create SystemVariable
        system_var = SystemVariable(workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test repr
        repr_str = repr(read_only_view)
        assert "SystemVariableReadOnlyView" in repr_str
        assert "system_variable=" in repr_str

    def test_none_value_handling(self):
        """Test that None values are properly handled."""
        # Create SystemVariable with all None values except workflow_execution_id
        system_var = SystemVariable(
            user_id=None,
            app_id=None,
            workflow_id=None,
            workflow_execution_id="exec-123",
            query=None,
            conversation_id=None,
            dialogue_count=None,
            document_id=None,
            original_document_id=None,
            dataset_id=None,
            batch=None,
            datasource_type=None,
            datasource_info=None,
            invoke_from=None,
        )

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test all None values
        assert read_only_view.user_id is None
        assert read_only_view.app_id is None
        assert read_only_view.workflow_id is None
        assert read_only_view.query is None
        assert read_only_view.conversation_id is None
        assert read_only_view.dialogue_count is None
        assert read_only_view.document_id is None
        assert read_only_view.original_document_id is None
        assert read_only_view.dataset_id is None
        assert read_only_view.batch is None
        assert read_only_view.datasource_type is None
        assert read_only_view.datasource_info is None
        assert read_only_view.invoke_from is None

        # files should be empty tuple even when default list is empty
        assert read_only_view.files == ()

    def test_empty_files_handling(self):
        """Test that empty files list is handled correctly."""
        # Create SystemVariable with empty files
        system_var = SystemVariable(files=[], workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test files handling
        assert read_only_view.files == ()
        assert isinstance(read_only_view.files, tuple)

    def test_empty_datasource_info_handling(self):
        """Test that empty datasource_info is handled correctly."""
        # Create SystemVariable with empty datasource_info
        system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test datasource_info handling
        assert read_only_view.datasource_info == {}
        # Should be a copy, not the same object
        assert read_only_view.datasource_info is not system_var.datasource_info
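
The pattern these tests verify, live reads from the wrapped instance but defensive copies of anything mutable, can be shown in miniature. Whether the real view uses MappingProxyType is an assumption; it is simply one standard way to produce the TypeError the test above expects:

    from types import MappingProxyType

    class ReadOnlyViewSketch:
        """Hypothetical stand-in illustrating the copy discipline."""

        def __init__(self, files: list, info: dict) -> None:
            self._files = files
            self._info = info

        @property
        def files(self) -> tuple:
            return tuple(self._files)                   # fresh immutable copy per access

        @property
        def info(self):
            return MappingProxyType(dict(self._info))   # read-only mapping copy

    view = ReadOnlyViewSketch(files=["f1"], info={"key": "original_value"})
    assert view.files == ("f1",)                        # mutating this tuple is impossible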

@ -0,0 +1,11 @@

from models.base import DefaultFieldsMixin


class FooModel(DefaultFieldsMixin):
    def __init__(self, id: str):
        self.id = id


def test_repr():
    foo_model = FooModel(id="test-id")
    assert repr(foo_model) == "<FooModel(id=test-id)>"

@ -0,0 +1,370 @@

"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""

from datetime import UTC, datetime
from unittest.mock import Mock, patch

import pytest
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
    DifyAPISQLAlchemyWorkflowRunRepository,
    _PrivateWorkflowPauseEntity,
    _WorkflowRunError,
)


class TestDifyAPISQLAlchemyWorkflowRunRepository:
    """Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""

    @pytest.fixture
    def mock_session(self):
        """Create a mock session."""
        return Mock(spec=Session)

    @pytest.fixture
    def mock_session_maker(self, mock_session):
        """Create a mock sessionmaker."""
        session_maker = Mock(spec=sessionmaker)

        # Create a context manager mock
        context_manager = Mock()
        context_manager.__enter__ = Mock(return_value=mock_session)
        context_manager.__exit__ = Mock(return_value=None)
        session_maker.return_value = context_manager

        # Mock session.begin() context manager
        begin_context_manager = Mock()
        begin_context_manager.__enter__ = Mock(return_value=None)
        begin_context_manager.__exit__ = Mock(return_value=None)
        mock_session.begin = Mock(return_value=begin_context_manager)

        # Add missing session methods
        mock_session.commit = Mock()
        mock_session.rollback = Mock()
        mock_session.add = Mock()
        mock_session.delete = Mock()
        mock_session.get = Mock()
        mock_session.scalar = Mock()
        mock_session.scalars = Mock()

        # Also support expire_on_commit parameter
        def make_session(expire_on_commit=None):
            cm = Mock()
            cm.__enter__ = Mock(return_value=mock_session)
            cm.__exit__ = Mock(return_value=None)
            return cm

        session_maker.side_effect = make_session
        return session_maker

    @pytest.fixture
    def repository(self, mock_session_maker):
        """Create repository instance with mocked dependencies."""

        # Create a testable subclass that implements the save method
        class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
            def __init__(self, session_maker):
                # Initialize without calling parent __init__ to avoid any instantiation issues
                self._session_maker = session_maker

            def save(self, execution):
                """Mock implementation of save method."""
                return None

        # Create repository instance
        repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)

        return repo

    @pytest.fixture
    def sample_workflow_run(self):
        """Create a sample WorkflowRun model."""
        workflow_run = Mock(spec=WorkflowRun)
        workflow_run.id = "workflow-run-123"
        workflow_run.tenant_id = "tenant-123"
        workflow_run.app_id = "app-123"
        workflow_run.workflow_id = "workflow-123"
        workflow_run.status = WorkflowExecutionStatus.RUNNING
        return workflow_run

    @pytest.fixture
    def sample_workflow_pause(self):
        """Create a sample WorkflowPauseModel."""
        pause = Mock(spec=WorkflowPauseModel)
        pause.id = "pause-123"
        pause.workflow_id = "workflow-123"
        pause.workflow_run_id = "workflow-run-123"
        pause.state_object_key = "workflow-state-123.json"
        pause.resumed_at = None
        pause.created_at = datetime.now(UTC)
        return pause


class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test create_workflow_pause method."""

    def test_create_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
    ):
        """Test successful workflow pause creation."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        state_owner_user_id = "user-123"
        state = '{"test": "state"}'

        mock_session.get.return_value = sample_workflow_run

        with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
            mock_uuidv7.side_effect = ["pause-123"]
            with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
                # Act
                result = repository.create_workflow_pause(
                    workflow_run_id=workflow_run_id,
                    state_owner_user_id=state_owner_user_id,
                    state=state,
                )

                # Assert
                assert isinstance(result, _PrivateWorkflowPauseEntity)
                assert result.id == "pause-123"
                assert result.workflow_execution_id == workflow_run_id

                # Verify database interactions
                mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
                mock_storage.save.assert_called_once()
                mock_session.add.assert_called()
                # When using the session.begin() context manager, commit is handled
                # automatically; no explicit commit call is expected.

    def test_create_workflow_pause_not_found(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
    ):
        """Test workflow pause creation when workflow run not found."""
        # Arrange
        mock_session.get.return_value = None

        # Act & Assert
        with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
            repository.create_workflow_pause(
                workflow_run_id="workflow-run-123",
                state_owner_user_id="user-123",
                state='{"test": "state"}',
            )

        mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")

    def test_create_workflow_pause_invalid_status(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
    ):
        """Test workflow pause creation when workflow not in RUNNING status."""
        # Arrange
        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        mock_session.get.return_value = sample_workflow_run

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
            repository.create_workflow_pause(
                workflow_run_id="workflow-run-123",
                state_owner_user_id="user-123",
                state='{"test": "state"}',
            )


class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test resume_workflow_pause method."""

    def test_resume_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test successful workflow pause resume."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        # Setup workflow run and pause
        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        sample_workflow_run.pause = sample_workflow_pause
        sample_workflow_pause.resumed_at = None

        mock_session.scalar.return_value = sample_workflow_run

        with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
            mock_now.return_value = datetime.now(UTC)

            # Act
            result = repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )

            # Assert
            assert isinstance(result, _PrivateWorkflowPauseEntity)
            assert result.id == "pause-123"

            # Verify state transitions
            assert sample_workflow_pause.resumed_at is not None
            assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING

            # Verify database interactions
            mock_session.add.assert_called()
            # When using the session.begin() context manager, commit is handled
            # automatically; no explicit commit call is expected.

    def test_resume_workflow_pause_not_paused(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
    ):
        """Test resume when workflow is not paused."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
        mock_session.scalar.return_value = sample_workflow_run

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )

    def test_resume_workflow_pause_id_mismatch(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test resume when pause ID doesn't match."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-456"  # Different ID

        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        sample_workflow_pause.id = "pause-123"
        sample_workflow_run.pause = sample_workflow_pause
        mock_session.scalar.return_value = sample_workflow_run

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )


class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test delete_workflow_pause method."""

    def test_delete_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test successful workflow pause deletion."""
        # Arrange
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        mock_session.get.return_value = sample_workflow_pause

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            # Act
            repository.delete_workflow_pause(pause_entity=pause_entity)

            # Assert
            mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
            mock_session.delete.assert_called_once_with(sample_workflow_pause)
            # When using the session.begin() context manager, commit is handled
            # automatically; no explicit commit call is expected.

    def test_delete_workflow_pause_not_found(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
    ):
        """Test delete when pause not found."""
        # Arrange
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        mock_session.get.return_value = None

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
            repository.delete_workflow_pause(pause_entity=pause_entity)


class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test _PrivateWorkflowPauseEntity class."""

    def test_from_models(self, sample_workflow_pause: Mock):
        """Test creating _PrivateWorkflowPauseEntity from models."""
        # Act
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)

        # Assert
        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model == sample_workflow_pause

    def test_properties(self, sample_workflow_pause: Mock):
        """Test entity properties."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)

        # Act & Assert
        assert entity.id == sample_workflow_pause.id
        assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
        assert entity.resumed_at == sample_workflow_pause.resumed_at

    def test_get_state(self, sample_workflow_pause: Mock):
        """Test getting state from storage."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state

            # Act
            result = entity.get_state()

            # Assert
            assert result == expected_state
            mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)

    def test_get_state_caching(self, sample_workflow_pause: Mock):
        """Test state caching in get_state method."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state

            # Act
            result1 = entity.get_state()
            result2 = entity.get_state()  # Should use cache

            # Assert
            assert result1 == expected_state
            assert result2 == expected_state
            mock_storage.load.assert_called_once()  # Only called once due to caching
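
A note on the repeated "no explicit commit" comments above: Session.begin() returns a context manager that commits on clean exit and rolls back on exception, so repository code written with it never calls session.commit() directly. A self-contained illustration against an in-memory SQLite engine, assuming nothing beyond SQLAlchemy itself:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    with Session(engine) as session:
        with session.begin():                       # opens a transaction
            session.execute(text("CREATE TABLE t (x INTEGER)"))
            session.execute(text("INSERT INTO t VALUES (1)"))
        # leaving the begin() block committed the transaction; no session.commit() anywhere
        assert session.execute(text("SELECT x FROM t")).scalar_one() == 1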

@ -0,0 +1,200 @@

"""Comprehensive unit tests for WorkflowRunService class.

This test suite covers all pause state management operations including:
- Retrieving pause state for workflow runs
- Saving pause state with file uploads
- Marking paused workflows as resumed
- Error handling and edge cases
- Database transaction management
- Repository-based approach testing
"""

from datetime import datetime
from unittest.mock import MagicMock, create_autospec, patch

import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.enums import WorkflowExecutionStatus
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import (
    WorkflowRunService,
)


class TestDataFactory:
    """Factory class for creating test data objects."""

    @staticmethod
    def create_workflow_run_mock(
        id: str = "workflow-run-123",
        tenant_id: str = "tenant-456",
        app_id: str = "app-789",
        workflow_id: str = "workflow-101",
        status: str | WorkflowExecutionStatus = "paused",
        pause_id: str | None = None,
        **kwargs,
    ) -> MagicMock:
        """Create a mock WorkflowRun object."""
        mock_run = MagicMock()
        mock_run.id = id
        mock_run.tenant_id = tenant_id
        mock_run.app_id = app_id
        mock_run.workflow_id = workflow_id
        mock_run.status = status
        mock_run.pause_id = pause_id

        for key, value in kwargs.items():
            setattr(mock_run, key, value)

        return mock_run

    @staticmethod
    def create_workflow_pause_mock(
        id: str = "pause-123",
        tenant_id: str = "tenant-456",
        app_id: str = "app-789",
        workflow_id: str = "workflow-101",
        workflow_execution_id: str = "workflow-execution-123",
        state_file_id: str = "file-456",
        resumed_at: datetime | None = None,
        **kwargs,
    ) -> MagicMock:
        """Create a mock WorkflowPauseModel object."""
        mock_pause = MagicMock()
        mock_pause.id = id
        mock_pause.tenant_id = tenant_id
        mock_pause.app_id = app_id
        mock_pause.workflow_id = workflow_id
        mock_pause.workflow_execution_id = workflow_execution_id
        mock_pause.state_file_id = state_file_id
        mock_pause.resumed_at = resumed_at

        for key, value in kwargs.items():
            setattr(mock_pause, key, value)

        return mock_pause

    @staticmethod
    def create_upload_file_mock(
        id: str = "file-456",
        key: str = "upload_files/test/state.json",
        name: str = "state.json",
        tenant_id: str = "tenant-456",
        **kwargs,
    ) -> MagicMock:
        """Create a mock UploadFile object."""
        mock_file = MagicMock()
        mock_file.id = id
        mock_file.key = key
        mock_file.name = name
        mock_file.tenant_id = tenant_id

        for key, value in kwargs.items():
            setattr(mock_file, key, value)

        return mock_file

    @staticmethod
    def create_pause_entity_mock(
        pause_model: MagicMock | None = None,
        upload_file: MagicMock | None = None,
    ) -> _PrivateWorkflowPauseEntity:
        """Create a mock _PrivateWorkflowPauseEntity object."""
        if pause_model is None:
            pause_model = TestDataFactory.create_workflow_pause_mock()
        if upload_file is None:
            upload_file = TestDataFactory.create_upload_file_mock()

        return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)


class TestWorkflowRunService:
    """Comprehensive unit tests for WorkflowRunService class."""

    @pytest.fixture
    def mock_session_factory(self):
        """Create a mock session factory with proper session management."""
        mock_session = create_autospec(Session)

        # Create a mock context manager for the session
        mock_session_cm = MagicMock()
        mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
        mock_session_cm.__exit__ = MagicMock(return_value=None)

        # Create a mock context manager for the transaction
        mock_transaction_cm = MagicMock()
        mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
        mock_transaction_cm.__exit__ = MagicMock(return_value=None)

        mock_session.begin = MagicMock(return_value=mock_transaction_cm)

        # Create mock factory that returns the context manager
        mock_factory = MagicMock(spec=sessionmaker)
        mock_factory.return_value = mock_session_cm

        return mock_factory, mock_session

    @pytest.fixture
    def mock_workflow_run_repository(self):
        """Create a mock APIWorkflowRunRepository."""
        mock_repo = create_autospec(APIWorkflowRunRepository)
        return mock_repo

    @pytest.fixture
    def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
        """Create WorkflowRunService instance with mocked dependencies."""
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            service = WorkflowRunService(session_factory)
            return service

    @pytest.fixture
    def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
        """Create WorkflowRunService instance with Engine input."""
        mock_engine = create_autospec(Engine)
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            service = WorkflowRunService(mock_engine)
            return service

    # ==================== Initialization Tests ====================

    def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
        """Test WorkflowRunService initialization with session_factory."""
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            service = WorkflowRunService(session_factory)

            assert service._session_factory == session_factory
            mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)

    def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
        """Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
        mock_engine = create_autospec(Engine)
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
                service = WorkflowRunService(mock_engine)

                mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
                assert service._session_factory == session_factory
                mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)

    def test_init_with_default_dependencies(self, mock_session_factory):
        """Test WorkflowRunService initialization with default dependencies."""
        session_factory, _ = mock_session_factory

        service = WorkflowRunService(session_factory)

        assert service._session_factory == session_factory
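
The initialization tests above imply the service's construction contract: it accepts either a ready-made sessionmaker or a bare Engine, normalizing the engine itself with expire_on_commit=False before building its repository. A brief sketch of both entry points; the SQLite URL is illustrative:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    engine = create_engine("sqlite://")
    service_a = WorkflowRunService(engine)  # engine is wrapped in a sessionmaker internally
    service_b = WorkflowRunService(sessionmaker(bind=engine, expire_on_commit=False))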