mirror of https://github.com/langgenius/dify.git
feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
fd7c4e8a6d
commit
a1c0bd7a1c
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.enums import WorkflowType
|
from core.workflow.enums import WorkflowType
|
||||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||||
|
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
app: App,
|
app: App,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
variable_loader=variable_loader,
|
variable_loader=variable_loader,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
graph_engine_layers=graph_engine_layers,
|
||||||
)
|
)
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
|
|
@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_entry.graph_engine.layer(persistence_layer)
|
workflow_entry.graph_engine.layer(persistence_layer)
|
||||||
|
for layer in self._graph_engine_layers:
|
||||||
|
workflow_entry.graph_engine.layer(layer)
|
||||||
|
|
||||||
generator = workflow_entry.run()
|
generator = workflow_entry.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_entry.graph_engine.layer(persistence_layer)
|
workflow_entry.graph_engine.layer(persistence_layer)
|
||||||
|
for layer in self._graph_engine_layers:
|
||||||
|
workflow_entry.graph_engine.layer(layer)
|
||||||
|
|
||||||
generator = workflow_entry.run()
|
generator = workflow_entry.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
|
@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
|
||||||
)
|
)
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
|
|
@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
):
|
):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._variable_loader = variable_loader
|
self._variable_loader = variable_loader
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
|
self._graph_engine_layers = graph_engine_layers
|
||||||
|
|
||||||
def _init_graph(
|
def _init_graph(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||||
|
from core.workflow.graph_events.base import GraphEngineEvent
|
||||||
|
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||||
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
|
||||||
|
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||||
|
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
|
||||||
|
"""Create a PauseStatePersistenceLayer.
|
||||||
|
|
||||||
|
The `state_owner_user_id` is used when creating state file for pause.
|
||||||
|
It generally should id of the creator of workflow.
|
||||||
|
"""
|
||||||
|
if isinstance(session_factory, Engine):
|
||||||
|
session_factory = sessionmaker(session_factory)
|
||||||
|
self._session_maker = session_factory
|
||||||
|
self._state_owner_user_id = state_owner_user_id
|
||||||
|
|
||||||
|
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||||
|
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||||
|
|
||||||
|
def on_graph_start(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when graph execution starts.
|
||||||
|
|
||||||
|
This is called after the engine has been initialized but before any nodes
|
||||||
|
are executed. Layers can use this to set up resources or log start information.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_event(self, event: GraphEngineEvent) -> None:
|
||||||
|
"""
|
||||||
|
Called for every event emitted by the engine.
|
||||||
|
|
||||||
|
This method receives all events generated during graph execution, including:
|
||||||
|
- Graph lifecycle events (start, success, failure)
|
||||||
|
- Node execution events (start, success, failure, retry)
|
||||||
|
- Stream events for response nodes
|
||||||
|
- Container events (iteration, loop)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event emitted by the engine
|
||||||
|
"""
|
||||||
|
if not isinstance(event, GraphRunPausedEvent):
|
||||||
|
return
|
||||||
|
|
||||||
|
assert self.graph_runtime_state is not None
|
||||||
|
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||||
|
assert workflow_run_id is not None
|
||||||
|
repo = self._get_repo()
|
||||||
|
repo.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
state_owner_user_id=self._state_owner_user_id,
|
||||||
|
state=self.graph_runtime_state.dumps(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_graph_end(self, error: Exception | None) -> None:
|
||||||
|
"""
|
||||||
|
Called when graph execution ends.
|
||||||
|
|
||||||
|
This is called after all nodes have been executed or when execution is
|
||||||
|
aborted. Layers can use this to clean up resources or log final state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The exception that caused execution to fail, or None if successful
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit
|
||||||
from .graph_init_params import GraphInitParams
|
from .graph_init_params import GraphInitParams
|
||||||
from .workflow_execution import WorkflowExecution
|
from .workflow_execution import WorkflowExecution
|
||||||
from .workflow_node_execution import WorkflowNodeExecution
|
from .workflow_node_execution import WorkflowNodeExecution
|
||||||
|
from .workflow_pause import WorkflowPauseEntity
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentNodeStrategyInit",
|
"AgentNodeStrategyInit",
|
||||||
|
|
@ -12,4 +13,5 @@ __all__ = [
|
||||||
"VariablePool",
|
"VariablePool",
|
||||||
"WorkflowExecution",
|
"WorkflowExecution",
|
||||||
"WorkflowNodeExecution",
|
"WorkflowNodeExecution",
|
||||||
|
"WorkflowPauseEntity",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
from enum import StrEnum, auto
|
||||||
|
from typing import Annotated, Any, ClassVar, TypeAlias
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Discriminator, Tag
|
||||||
|
|
||||||
|
|
||||||
|
class _PauseReasonType(StrEnum):
|
||||||
|
HUMAN_INPUT_REQUIRED = auto()
|
||||||
|
SCHEDULED_PAUSE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class _PauseReasonBase(BaseModel):
|
||||||
|
TYPE: ClassVar[_PauseReasonType]
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInputRequired(_PauseReasonBase):
|
||||||
|
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulingPause(_PauseReasonBase):
|
||||||
|
TYPE = _PauseReasonType.SCHEDULED_PAUSE
|
||||||
|
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
|
||||||
|
if isinstance(v, _PauseReasonBase):
|
||||||
|
return v.TYPE
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
reason_type_str = v.get("TYPE")
|
||||||
|
if reason_type_str is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
reason_type = _PauseReasonType(reason_type_str)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return reason_type
|
||||||
|
else:
|
||||||
|
# return None if the discriminator value isn't found
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
PauseReason: TypeAlias = Annotated[
|
||||||
|
(
|
||||||
|
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
|
||||||
|
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
|
||||||
|
),
|
||||||
|
Discriminator(_get_pause_reason_discriminator),
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
"""
|
||||||
|
Domain entities for workflow pause management.
|
||||||
|
|
||||||
|
This module contains the domain model for workflow pause, which is used
|
||||||
|
by the core workflow module. These models are independent of the storage mechanism
|
||||||
|
and don't contain implementation details like tenant_id, app_id, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowPauseEntity(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for workflow pause entities.
|
||||||
|
|
||||||
|
This domain model represents a paused workflow execution state,
|
||||||
|
without implementation details like tenant_id, app_id, etc.
|
||||||
|
It provides the interface for managing workflow pause/resume operations
|
||||||
|
and state persistence through file storage.
|
||||||
|
|
||||||
|
The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
|
||||||
|
it will generate multiple `WorkflowPauseEntity` records.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def id(self) -> str:
|
||||||
|
"""The identifier of current WorkflowPauseEntity"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def workflow_execution_id(self) -> str:
|
||||||
|
"""The identifier of the workflow execution record the pause associated with.
|
||||||
|
Correspond to `WorkflowExecution.id`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_state(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Retrieve the serialized workflow state from storage.
|
||||||
|
|
||||||
|
This method should load and return the workflow execution state
|
||||||
|
that was saved when the workflow was paused. The state contains
|
||||||
|
all necessary information to resume the workflow execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The serialized workflow state containing
|
||||||
|
execution context, variable values, node states, etc.
|
||||||
|
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def resumed_at(self) -> datetime | None:
|
||||||
|
"""`resumed_at` return the resumption time of the current pause, or `None` if
|
||||||
|
the pause is not resumed yet.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
@ -92,13 +92,111 @@ class WorkflowType(StrEnum):
|
||||||
|
|
||||||
|
|
||||||
class WorkflowExecutionStatus(StrEnum):
|
class WorkflowExecutionStatus(StrEnum):
|
||||||
|
# State diagram for the workflw status:
|
||||||
|
# (@) means start, (*) means end
|
||||||
|
#
|
||||||
|
# ┌------------------>------------------------->------------------->--------------┐
|
||||||
|
# | |
|
||||||
|
# | ┌-----------------------<--------------------┐ |
|
||||||
|
# ^ | | |
|
||||||
|
# | | ^ |
|
||||||
|
# | V | |
|
||||||
|
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
|
||||||
|
# | Scheduled |------->| Running |---------------------->| paused | |
|
||||||
|
# └-----------┘ └-----------------------┘ └-----------┘ |
|
||||||
|
# | | | | | | |
|
||||||
|
# | | | | | | |
|
||||||
|
# ^ | | | V V |
|
||||||
|
# | | | | | ┌---------┐ |
|
||||||
|
# (@) | | | └------------------------>| Stopped |<----┘
|
||||||
|
# | | | └---------┘
|
||||||
|
# | | | |
|
||||||
|
# | | V V
|
||||||
|
# | | ┌-----------┐ |
|
||||||
|
# | | | Succeeded |------------->--------------┤
|
||||||
|
# | | └-----------┘ |
|
||||||
|
# | V V
|
||||||
|
# | +--------┐ |
|
||||||
|
# | | Failed |---------------------->----------------┤
|
||||||
|
# | └--------┘ |
|
||||||
|
# V V
|
||||||
|
# ┌---------------------┐ |
|
||||||
|
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
|
||||||
|
# └---------------------┘
|
||||||
|
#
|
||||||
|
# Mermaid diagram:
|
||||||
|
#
|
||||||
|
# ---
|
||||||
|
# title: State diagram for Workflow run state
|
||||||
|
# ---
|
||||||
|
# stateDiagram-v2
|
||||||
|
# scheduled: Scheduled
|
||||||
|
# running: Running
|
||||||
|
# succeeded: Succeeded
|
||||||
|
# failed: Failed
|
||||||
|
# partial_succeeded: Partial Succeeded
|
||||||
|
# paused: Paused
|
||||||
|
# stopped: Stopped
|
||||||
|
#
|
||||||
|
# [*] --> scheduled:
|
||||||
|
# scheduled --> running: Start Execution
|
||||||
|
# running --> paused: Human input required
|
||||||
|
# paused --> running: human input added
|
||||||
|
# paused --> stopped: User stops execution
|
||||||
|
# running --> succeeded: Execution finishes without any error
|
||||||
|
# running --> failed: Execution finishes with errors
|
||||||
|
# running --> stopped: User stops execution
|
||||||
|
# running --> partial_succeeded: some execution occurred and handled during execution
|
||||||
|
#
|
||||||
|
# scheduled --> stopped: User stops execution
|
||||||
|
#
|
||||||
|
# succeeded --> [*]
|
||||||
|
# failed --> [*]
|
||||||
|
# partial_succeeded --> [*]
|
||||||
|
# stopped --> [*]
|
||||||
|
|
||||||
|
# `SCHEDULED` means that the workflow is scheduled to run, but has not
|
||||||
|
# started running yet. (maybe due to possible worker saturation.)
|
||||||
|
#
|
||||||
|
# This enum value is currently unused.
|
||||||
|
SCHEDULED = "scheduled"
|
||||||
|
|
||||||
|
# `RUNNING` means the workflow is exeuting.
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
|
|
||||||
|
# `SUCCEEDED` means the execution of workflow succeed without any error.
|
||||||
SUCCEEDED = "succeeded"
|
SUCCEEDED = "succeeded"
|
||||||
|
|
||||||
|
# `FAILED` means the execution of workflow failed without some errors.
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
|
# `STOPPED` means the execution of workflow was stopped, either manually
|
||||||
|
# by the user, or automatically by the Dify application (E.G. the moderation
|
||||||
|
# mechanism.)
|
||||||
STOPPED = "stopped"
|
STOPPED = "stopped"
|
||||||
|
|
||||||
|
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
|
||||||
|
# execution, but they were successfully handled (e.g., by using an error
|
||||||
|
# strategy such as "fail branch" or "default value").
|
||||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||||
|
|
||||||
|
# `PAUSED` indicates that the workflow execution is temporarily paused
|
||||||
|
# (e.g., awaiting human input) and is expected to resume later.
|
||||||
PAUSED = "paused"
|
PAUSED = "paused"
|
||||||
|
|
||||||
|
def is_ended(self) -> bool:
|
||||||
|
return self in _END_STATE
|
||||||
|
|
||||||
|
|
||||||
|
_END_STATE = frozenset(
|
||||||
|
[
|
||||||
|
WorkflowExecutionStatus.SUCCEEDED,
|
||||||
|
WorkflowExecutionStatus.FAILED,
|
||||||
|
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||||
|
WorkflowExecutionStatus.STOPPED,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ from typing import final
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
|
|
||||||
from ..domain.graph_execution import GraphExecution
|
from ..domain.graph_execution import GraphExecution
|
||||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||||
from .command_processor import CommandHandler
|
from .command_processor import CommandHandler
|
||||||
|
|
@ -25,4 +27,7 @@ class PauseCommandHandler(CommandHandler):
|
||||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||||
assert isinstance(command, PauseCommand)
|
assert isinstance(command, PauseCommand)
|
||||||
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
||||||
execution.pause(command.reason)
|
# Convert string reason to PauseReason if needed
|
||||||
|
reason = command.reason
|
||||||
|
pause_reason = SchedulingPause(message=reason)
|
||||||
|
execution.pause(pause_reason)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
from core.workflow.enums import NodeState
|
from core.workflow.enums import NodeState
|
||||||
|
|
||||||
from .node_execution import NodeExecution
|
from .node_execution import NodeExecution
|
||||||
|
|
@ -41,7 +42,7 @@ class GraphExecutionState(BaseModel):
|
||||||
completed: bool = Field(default=False)
|
completed: bool = Field(default=False)
|
||||||
aborted: bool = Field(default=False)
|
aborted: bool = Field(default=False)
|
||||||
paused: bool = Field(default=False)
|
paused: bool = Field(default=False)
|
||||||
pause_reason: str | None = Field(default=None)
|
pause_reason: PauseReason | None = Field(default=None)
|
||||||
error: GraphExecutionErrorState | None = Field(default=None)
|
error: GraphExecutionErrorState | None = Field(default=None)
|
||||||
exceptions_count: int = Field(default=0)
|
exceptions_count: int = Field(default=0)
|
||||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||||
|
|
@ -106,7 +107,7 @@ class GraphExecution:
|
||||||
completed: bool = False
|
completed: bool = False
|
||||||
aborted: bool = False
|
aborted: bool = False
|
||||||
paused: bool = False
|
paused: bool = False
|
||||||
pause_reason: str | None = None
|
pause_reason: PauseReason | None = None
|
||||||
error: Exception | None = None
|
error: Exception | None = None
|
||||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||||
exceptions_count: int = 0
|
exceptions_count: int = 0
|
||||||
|
|
@ -130,7 +131,7 @@ class GraphExecution:
|
||||||
self.aborted = True
|
self.aborted = True
|
||||||
self.error = RuntimeError(f"Aborted: {reason}")
|
self.error = RuntimeError(f"Aborted: {reason}")
|
||||||
|
|
||||||
def pause(self, reason: str | None = None) -> None:
|
def pause(self, reason: PauseReason) -> None:
|
||||||
"""Pause the graph execution without marking it complete."""
|
"""Pause the graph execution without marking it complete."""
|
||||||
if self.completed:
|
if self.completed:
|
||||||
raise RuntimeError("Cannot pause execution that has completed")
|
raise RuntimeError("Cannot pause execution that has completed")
|
||||||
|
|
|
||||||
|
|
@ -36,4 +36,4 @@ class PauseCommand(GraphEngineCommand):
|
||||||
"""Command to pause a running workflow execution."""
|
"""Command to pause a running workflow execution."""
|
||||||
|
|
||||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||||
reason: str | None = Field(default=None, description="Optional reason for pause")
|
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ class EventHandler:
|
||||||
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||||
"""Handle pause requests emitted by nodes."""
|
"""Handle pause requests emitted by nodes."""
|
||||||
|
|
||||||
pause_reason = event.reason or "Awaiting human input"
|
pause_reason = event.reason
|
||||||
self._graph_execution.pause(pause_reason)
|
self._graph_execution.pause(pause_reason)
|
||||||
self._state_manager.finish_execution(event.node_id)
|
self._state_manager.finish_execution(event.node_id)
|
||||||
if event.node_id in self._graph.nodes:
|
if event.node_id in self._graph.nodes:
|
||||||
|
|
|
||||||
|
|
@ -247,8 +247,11 @@ class GraphEngine:
|
||||||
|
|
||||||
# Handle completion
|
# Handle completion
|
||||||
if self._graph_execution.is_paused:
|
if self._graph_execution.is_paused:
|
||||||
|
pause_reason = self._graph_execution.pause_reason
|
||||||
|
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
|
||||||
|
# Ensure we have a valid PauseReason for the event
|
||||||
paused_event = GraphRunPausedEvent(
|
paused_event = GraphRunPausedEvent(
|
||||||
reason=self._graph_execution.pause_reason,
|
reason=pause_reason,
|
||||||
outputs=self._graph_runtime_state.outputs,
|
outputs=self._graph_runtime_state.outputs,
|
||||||
)
|
)
|
||||||
self._event_manager.notify_layers(paused_event)
|
self._event_manager.notify_layers(paused_event)
|
||||||
|
|
|
||||||
|
|
@ -216,7 +216,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||||
execution = self._get_workflow_execution()
|
execution = self._get_workflow_execution()
|
||||||
execution.status = WorkflowExecutionStatus.PAUSED
|
execution.status = WorkflowExecutionStatus.PAUSED
|
||||||
execution.error_message = event.reason or "Workflow execution paused"
|
|
||||||
execution.outputs = event.outputs
|
execution.outputs = event.outputs
|
||||||
self._populate_completion_statistics(execution, update_finished=False)
|
self._populate_completion_statistics(execution, update_finished=False)
|
||||||
|
|
||||||
|
|
@ -296,7 +295,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||||
domain_execution,
|
domain_execution,
|
||||||
event.node_run_result,
|
event.node_run_result,
|
||||||
WorkflowNodeExecutionStatus.PAUSED,
|
WorkflowNodeExecutionStatus.PAUSED,
|
||||||
error=event.reason,
|
error="",
|
||||||
update_outputs=False,
|
update_outputs=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
from core.workflow.graph_events import BaseGraphEvent
|
from core.workflow.graph_events import BaseGraphEvent
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,7 +45,8 @@ class GraphRunAbortedEvent(BaseGraphEvent):
|
||||||
class GraphRunPausedEvent(BaseGraphEvent):
|
class GraphRunPausedEvent(BaseGraphEvent):
|
||||||
"""Event emitted when a graph run is paused by user command."""
|
"""Event emitted when a graph run is paused by user command."""
|
||||||
|
|
||||||
reason: str | None = Field(default=None, description="reason for pause")
|
# reason: str | None = Field(default=None, description="reason for pause")
|
||||||
|
reason: PauseReason = Field(..., description="reason for pause")
|
||||||
outputs: dict[str, object] = Field(
|
outputs: dict[str, object] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Outputs available to the client while the run is paused.",
|
description="Outputs available to the client while the run is paused.",
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from pydantic import Field
|
||||||
|
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities import AgentNodeStrategyInit
|
from core.workflow.entities import AgentNodeStrategyInit
|
||||||
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
|
|
||||||
from .base import GraphNodeEventBase
|
from .base import GraphNodeEventBase
|
||||||
|
|
||||||
|
|
@ -54,4 +55,4 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||||
|
|
||||||
|
|
||||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
reason: PauseReason = Field(..., description="pause reason")
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from pydantic import Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
|
|
||||||
from .base import NodeEventBase
|
from .base import NodeEventBase
|
||||||
|
|
@ -43,4 +44,4 @@ class StreamCompletedEvent(NodeEventBase):
|
||||||
|
|
||||||
|
|
||||||
class PauseRequestedEvent(NodeEventBase):
|
class PauseRequestedEvent(NodeEventBase):
|
||||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
reason: PauseReason = Field(..., description="pause reason")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
|
|
@ -64,7 +65,7 @@ class HumanInputNode(Node):
|
||||||
return self._pause_generator()
|
return self._pause_generator()
|
||||||
|
|
||||||
def _pause_generator(self):
|
def _pause_generator(self):
|
||||||
yield PauseRequestedEvent(reason=self._node_data.pause_reason)
|
yield PauseRequestedEvent(reason=HumanInputRequired())
|
||||||
|
|
||||||
def _is_completion_ready(self) -> bool:
|
def _is_completion_ready(self) -> bool:
|
||||||
"""Determine whether all required inputs are satisfied."""
|
"""Determine whether all required inputs are satisfied."""
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from typing import Any, Protocol
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables.segments import Segment
|
from core.variables.segments import Segment
|
||||||
|
from core.workflow.system_variable import SystemVariableReadOnlyView
|
||||||
|
|
||||||
|
|
||||||
class ReadOnlyVariablePool(Protocol):
|
class ReadOnlyVariablePool(Protocol):
|
||||||
|
|
@ -30,6 +31,9 @@ class ReadOnlyGraphRuntimeState(Protocol):
|
||||||
All methods return defensive copies to ensure immutability.
|
All methods return defensive copies to ensure immutability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def system_variable(self) -> SystemVariableReadOnlyView: ...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||||
"""Get read-only access to the variable pool."""
|
"""Get read-only access to the variable pool."""
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Any
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables.segments import Segment
|
from core.variables.segments import Segment
|
||||||
|
from core.workflow.system_variable import SystemVariableReadOnlyView
|
||||||
|
|
||||||
from .graph_runtime_state import GraphRuntimeState
|
from .graph_runtime_state import GraphRuntimeState
|
||||||
from .variable_pool import VariablePool
|
from .variable_pool import VariablePool
|
||||||
|
|
@ -42,6 +43,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
|
||||||
self._state = state
|
self._state = state
|
||||||
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def system_variable(self) -> SystemVariableReadOnlyView:
|
||||||
|
return self._state.variable_pool.system_variables.as_view()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
||||||
return self._variable_pool_wrapper
|
return self._variable_pool_wrapper
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
|
from types import MappingProxyType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
@ -108,3 +109,102 @@ class SystemVariable(BaseModel):
|
||||||
if self.invoke_from is not None:
|
if self.invoke_from is not None:
|
||||||
d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
|
d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def as_view(self) -> "SystemVariableReadOnlyView":
|
||||||
|
return SystemVariableReadOnlyView(self)
|
||||||
|
|
||||||
|
|
||||||
|
class SystemVariableReadOnlyView:
|
||||||
|
"""
|
||||||
|
A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol.
|
||||||
|
|
||||||
|
This class wraps a SystemVariable instance and provides read-only access to all its fields.
|
||||||
|
It always reads the latest data from the wrapped instance and prevents any write operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, system_variable: SystemVariable) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the read-only view with a SystemVariable instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_variable: The SystemVariable instance to wrap
|
||||||
|
"""
|
||||||
|
self._system_variable = system_variable
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self) -> str | None:
|
||||||
|
return self._system_variable.user_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app_id(self) -> str | None:
|
||||||
|
return self._system_variable.app_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workflow_id(self) -> str | None:
|
||||||
|
return self._system_variable.workflow_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workflow_execution_id(self) -> str | None:
|
||||||
|
return self._system_variable.workflow_execution_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self) -> str | None:
|
||||||
|
return self._system_variable.query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def conversation_id(self) -> str | None:
|
||||||
|
return self._system_variable.conversation_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dialogue_count(self) -> int | None:
|
||||||
|
return self._system_variable.dialogue_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def document_id(self) -> str | None:
|
||||||
|
return self._system_variable.document_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def original_document_id(self) -> str | None:
|
||||||
|
return self._system_variable.original_document_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset_id(self) -> str | None:
|
||||||
|
return self._system_variable.dataset_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch(self) -> str | None:
|
||||||
|
return self._system_variable.batch
|
||||||
|
|
||||||
|
@property
|
||||||
|
def datasource_type(self) -> str | None:
|
||||||
|
return self._system_variable.datasource_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def invoke_from(self) -> str | None:
|
||||||
|
return self._system_variable.invoke_from
|
||||||
|
|
||||||
|
@property
|
||||||
|
def files(self) -> Sequence[File]:
|
||||||
|
"""
|
||||||
|
Get a copy of the files from the wrapped SystemVariable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A defensive copy of the files sequence to prevent modification
|
||||||
|
"""
|
||||||
|
return tuple(self._system_variable.files) # Convert to immutable tuple
|
||||||
|
|
||||||
|
@property
|
||||||
|
def datasource_info(self) -> Mapping[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Get a copy of the datasource info from the wrapped SystemVariable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A view of the datasource info mapping to prevent modification
|
||||||
|
"""
|
||||||
|
if self._system_variable.datasource_info is None:
|
||||||
|
return None
|
||||||
|
return MappingProxyType(self._system_variable.datasource_info)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return a string representation of the read-only view."""
|
||||||
|
return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})"
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ class Storage:
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"unsupported storage type {storage_type}")
|
raise ValueError(f"unsupported storage type {storage_type}")
|
||||||
|
|
||||||
def save(self, filename, data):
|
def save(self, filename: str, data: bytes):
|
||||||
self.storage_runner.save(filename, data)
|
self.storage_runner.save(filename, data)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ class BaseStorage(ABC):
|
||||||
"""Interface for file storage."""
|
"""Interface for file storage."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, filename, data):
|
def save(self, filename: str, data: bytes):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""add WorkflowPause model
|
||||||
|
|
||||||
|
Revision ID: 03f8dcbc611e
|
||||||
|
Revises: ae662b25d9bc
|
||||||
|
Create Date: 2025-10-22 16:11:31.805407
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "03f8dcbc611e"
|
||||||
|
down_revision = "ae662b25d9bc"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"workflow_pauses",
|
||||||
|
sa.Column("workflow_id", models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column("resumed_at", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("state_object_key", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("workflow_pauses_pkey")),
|
||||||
|
sa.UniqueConstraint("workflow_run_id", name=op.f("workflow_pauses_workflow_run_id_key")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("workflow_pauses")
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
@ -88,6 +88,7 @@ from .workflow import (
|
||||||
WorkflowNodeExecutionModel,
|
WorkflowNodeExecutionModel,
|
||||||
WorkflowNodeExecutionOffload,
|
WorkflowNodeExecutionOffload,
|
||||||
WorkflowNodeExecutionTriggeredFrom,
|
WorkflowNodeExecutionTriggeredFrom,
|
||||||
|
WorkflowPause,
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowType,
|
WorkflowType,
|
||||||
)
|
)
|
||||||
|
|
@ -177,6 +178,7 @@ __all__ = [
|
||||||
"WorkflowNodeExecutionModel",
|
"WorkflowNodeExecutionModel",
|
||||||
"WorkflowNodeExecutionOffload",
|
"WorkflowNodeExecutionOffload",
|
||||||
"WorkflowNodeExecutionTriggeredFrom",
|
"WorkflowNodeExecutionTriggeredFrom",
|
||||||
|
"WorkflowPause",
|
||||||
"WorkflowRun",
|
"WorkflowRun",
|
||||||
"WorkflowRunTriggeredFrom",
|
"WorkflowRunTriggeredFrom",
|
||||||
"WorkflowToolProvider",
|
"WorkflowToolProvider",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, func, text
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
|
||||||
|
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from libs.uuid_utils import uuidv7
|
||||||
from models.engine import metadata
|
from models.engine import metadata
|
||||||
|
from models.types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
|
|
@ -13,3 +19,34 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
metadata = metadata
|
metadata = metadata
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultFieldsMixin:
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
StringUUID,
|
||||||
|
primary_key=True,
|
||||||
|
# NOTE: The default and server_default serve as fallback mechanisms.
|
||||||
|
# The application can generate the `id` before saving to optimize
|
||||||
|
# the insertion process (especially for interdependent models)
|
||||||
|
# and reduce database roundtrips.
|
||||||
|
default=uuidv7,
|
||||||
|
server_default=text("uuidv7()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
default=naive_utc_now,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
__name_pos=DateTime,
|
||||||
|
nullable=False,
|
||||||
|
default=naive_utc_now,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,11 @@ from core.file.constants import maybe_file_object
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
from core.variables import utils as variable_utils
|
from core.variables import utils as variable_utils
|
||||||
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import (
|
||||||
from core.workflow.enums import NodeType
|
CONVERSATION_VARIABLE_NODE_ID,
|
||||||
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
)
|
||||||
|
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
||||||
from extensions.ext_storage import Storage
|
from extensions.ext_storage import Storage
|
||||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
|
@ -35,7 +38,7 @@ from factories import variable_factory
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
|
||||||
from .account import Account
|
from .account import Account
|
||||||
from .base import Base
|
from .base import Base, DefaultFieldsMixin
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
|
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
|
||||||
from .types import EnumText, StringUUID
|
from .types import EnumText, StringUUID
|
||||||
|
|
@ -247,7 +250,9 @@ class Workflow(Base):
|
||||||
return node_type
|
return node_type
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None:
|
def get_enclosing_node_type_and_id(
|
||||||
|
node_config: Mapping[str, Any],
|
||||||
|
) -> tuple[NodeType, str] | None:
|
||||||
in_loop = node_config.get("isInLoop", False)
|
in_loop = node_config.get("isInLoop", False)
|
||||||
in_iteration = node_config.get("isInIteration", False)
|
in_iteration = node_config.get("isInIteration", False)
|
||||||
if in_loop:
|
if in_loop:
|
||||||
|
|
@ -306,7 +311,10 @@ class Workflow(Base):
|
||||||
if "nodes" not in graph_dict:
|
if "nodes" not in graph_dict:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None)
|
start_node = next(
|
||||||
|
(node for node in graph_dict["nodes"] if node["data"]["type"] == "start"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
if not start_node:
|
if not start_node:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -359,7 +367,9 @@ class Workflow(Base):
|
||||||
return db.session.execute(stmt).scalar_one()
|
return db.session.execute(stmt).scalar_one()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
def environment_variables(
|
||||||
|
self,
|
||||||
|
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||||
if self._environment_variables is None:
|
if self._environment_variables is None:
|
||||||
self._environment_variables = "{}"
|
self._environment_variables = "{}"
|
||||||
|
|
@ -376,7 +386,9 @@ class Workflow(Base):
|
||||||
]
|
]
|
||||||
|
|
||||||
# decrypt secret variables value
|
# decrypt secret variables value
|
||||||
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
def decrypt_func(
|
||||||
|
var: Variable,
|
||||||
|
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
||||||
if isinstance(var, SecretVariable):
|
if isinstance(var, SecretVariable):
|
||||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||||
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
|
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
|
||||||
|
|
@ -537,7 +549,10 @@ class WorkflowRun(Base):
|
||||||
version: Mapped[str] = mapped_column(String(255))
|
version: Mapped[str] = mapped_column(String(255))
|
||||||
graph: Mapped[str | None] = mapped_column(sa.Text)
|
graph: Mapped[str | None] = mapped_column(sa.Text)
|
||||||
inputs: Mapped[str | None] = mapped_column(sa.Text)
|
inputs: Mapped[str | None] = mapped_column(sa.Text)
|
||||||
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
status: Mapped[str] = mapped_column(
|
||||||
|
EnumText(WorkflowExecutionStatus, length=255),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}")
|
outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}")
|
||||||
error: Mapped[str | None] = mapped_column(sa.Text)
|
error: Mapped[str | None] = mapped_column(sa.Text)
|
||||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||||
|
|
@ -549,6 +564,15 @@ class WorkflowRun(Base):
|
||||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||||
|
|
||||||
|
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
|
||||||
|
"WorkflowPause",
|
||||||
|
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
|
||||||
|
uselist=False,
|
||||||
|
# require explicit preloading.
|
||||||
|
lazy="raise",
|
||||||
|
back_populates="workflow_run",
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatorUserRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
|
|
@ -1073,7 +1097,10 @@ class ConversationVariable(Base):
|
||||||
DateTime, nullable=False, server_default=func.current_timestamp(), index=True
|
DateTime, nullable=False, server_default=func.current_timestamp(), index=True
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
onupdate=func.current_timestamp(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str):
|
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str):
|
||||||
|
|
@ -1101,10 +1128,6 @@ class ConversationVariable(Base):
|
||||||
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
|
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
|
||||||
|
|
||||||
|
|
||||||
def _naive_utc_datetime():
|
|
||||||
return naive_utc_now()
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDraftVariable(Base):
|
class WorkflowDraftVariable(Base):
|
||||||
"""`WorkflowDraftVariable` record variables and outputs generated during
|
"""`WorkflowDraftVariable` record variables and outputs generated during
|
||||||
debugging workflow or chatflow.
|
debugging workflow or chatflow.
|
||||||
|
|
@ -1138,14 +1161,14 @@ class WorkflowDraftVariable(Base):
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
DateTime,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=_naive_utc_datetime,
|
default=naive_utc_now,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
DateTime,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=_naive_utc_datetime,
|
default=naive_utc_now,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
onupdate=func.current_timestamp(),
|
onupdate=func.current_timestamp(),
|
||||||
)
|
)
|
||||||
|
|
@ -1412,8 +1435,8 @@ class WorkflowDraftVariable(Base):
|
||||||
file_id: str | None = None,
|
file_id: str | None = None,
|
||||||
) -> "WorkflowDraftVariable":
|
) -> "WorkflowDraftVariable":
|
||||||
variable = WorkflowDraftVariable()
|
variable = WorkflowDraftVariable()
|
||||||
variable.created_at = _naive_utc_datetime()
|
variable.created_at = naive_utc_now()
|
||||||
variable.updated_at = _naive_utc_datetime()
|
variable.updated_at = naive_utc_now()
|
||||||
variable.description = description
|
variable.description = description
|
||||||
variable.app_id = app_id
|
variable.app_id = app_id
|
||||||
variable.node_id = node_id
|
variable.node_id = node_id
|
||||||
|
|
@ -1518,7 +1541,7 @@ class WorkflowDraftVariableFile(Base):
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
DateTime,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=_naive_utc_datetime,
|
default=naive_utc_now,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1583,3 +1606,68 @@ class WorkflowDraftVariableFile(Base):
|
||||||
|
|
||||||
def is_system_variable_editable(name: str) -> bool:
|
def is_system_variable_editable(name: str) -> bool:
|
||||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowPause(DefaultFieldsMixin, Base):
|
||||||
|
"""
|
||||||
|
WorkflowPause records the paused state and related metadata for a specific workflow run.
|
||||||
|
|
||||||
|
Each `WorkflowRun` can have zero or one associated `WorkflowPause`, depending on its execution status.
|
||||||
|
If a `WorkflowRun` is in the `PAUSED` state, there must be a corresponding `WorkflowPause`
|
||||||
|
that has not yet been resumed.
|
||||||
|
Otherwise, there should be no active (non-resumed) `WorkflowPause` linked to that run.
|
||||||
|
|
||||||
|
This model captures the execution context required to resume workflow processing at a later time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "workflow_pauses"
|
||||||
|
__table_args__ = (
|
||||||
|
# Design Note:
|
||||||
|
# Instead of adding a `pause_id` field to the `WorkflowRun` model—which would require a migration
|
||||||
|
# on a potentially large table—we reference `WorkflowRun` from `WorkflowPause` and enforce a unique
|
||||||
|
# constraint on `workflow_run_id` to guarantee a one-to-one relationship.
|
||||||
|
UniqueConstraint("workflow_run_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# `workflow_id` represents the unique identifier of the workflow associated with this pause.
|
||||||
|
# It corresponds to the `id` field in the `Workflow` model.
|
||||||
|
#
|
||||||
|
# Since an application can have multiple versions of a workflow, each with its own unique ID,
|
||||||
|
# the `app_id` alone is insufficient to determine which workflow version should be loaded
|
||||||
|
# when resuming a suspended workflow.
|
||||||
|
workflow_id: Mapped[str] = mapped_column(
|
||||||
|
StringUUID,
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# `workflow_run_id` represents the identifier of the execution of workflow,
|
||||||
|
# correspond to the `id` field of `WorkflowRun`.
|
||||||
|
workflow_run_id: Mapped[str] = mapped_column(
|
||||||
|
StringUUID,
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# `resumed_at` records the timestamp when the suspended workflow was resumed.
|
||||||
|
# It is set to `NULL` if the workflow has not been resumed.
|
||||||
|
#
|
||||||
|
# NOTE: Resuming a suspended WorkflowPause does not delete the record immediately.
|
||||||
|
# It only set `resumed_at` to a non-null value.
|
||||||
|
resumed_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
sa.DateTime,
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# state_object_key stores the object key referencing the serialized runtime state
|
||||||
|
# of the `GraphEngine`. This object captures the complete execution context of the
|
||||||
|
# workflow at the moment it was paused, enabling accurate resumption.
|
||||||
|
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
||||||
|
|
||||||
|
# Relationship to WorkflowRun
|
||||||
|
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
|
||||||
|
foreign_keys=[workflow_run_id],
|
||||||
|
# require explicit preloading.
|
||||||
|
lazy="raise",
|
||||||
|
uselist=False,
|
||||||
|
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
|
||||||
|
back_populates="pause",
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
|
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
|
|
@ -251,6 +252,116 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def create_workflow_pause(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
state_owner_user_id: str,
|
||||||
|
state: str,
|
||||||
|
) -> WorkflowPauseEntity:
|
||||||
|
"""
|
||||||
|
Create a new workflow pause state.
|
||||||
|
|
||||||
|
Creates a pause state for a workflow run, storing the current execution
|
||||||
|
state and marking the workflow as paused. This is used when a workflow
|
||||||
|
needs to be suspended and later resumed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: Identifier of the workflow run to pause
|
||||||
|
state_owner_user_id: User ID who owns the pause state for file storage
|
||||||
|
state: Serialized workflow execution state (JSON string)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkflowPauseEntity representing the created pause state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
|
||||||
|
RuntimeError: If workflow is already paused or in invalid state
|
||||||
|
"""
|
||||||
|
# NOTE: we may get rid of the `state_owner_user_id` in parameter list.
|
||||||
|
# However, removing it would require an extra for `Workflow` model
|
||||||
|
# while creating pause.
|
||||||
|
...
|
||||||
|
|
||||||
|
def resume_workflow_pause(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
pause_entity: WorkflowPauseEntity,
|
||||||
|
) -> WorkflowPauseEntity:
|
||||||
|
"""
|
||||||
|
Resume a paused workflow.
|
||||||
|
|
||||||
|
Marks a paused workflow as resumed, set the `resumed_at` field of WorkflowPauseEntity
|
||||||
|
and returning the workflow to running status. Returns the pause entity
|
||||||
|
that was resumed.
|
||||||
|
|
||||||
|
The returned `WorkflowPauseEntity` model has `resumed_at` set.
|
||||||
|
|
||||||
|
NOTE: this method does not delete the correspond `WorkflowPauseEntity` record and associated states.
|
||||||
|
It's the callers responsibility to clear the correspond state with `delete_workflow_pause`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: Identifier of the workflow run to resume
|
||||||
|
pause_entity: The pause entity to resume
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkflowPauseEntity representing the resumed pause state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workflow_run_id is invalid
|
||||||
|
RuntimeError: If workflow is not paused or already resumed
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete_workflow_pause(
|
||||||
|
self,
|
||||||
|
pause_entity: WorkflowPauseEntity,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Delete a workflow pause state.
|
||||||
|
|
||||||
|
Permanently removes the pause state for a workflow run, including
|
||||||
|
the stored state file. Used for cleanup operations when a paused
|
||||||
|
workflow is no longer needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pause_entity: The pause entity to delete
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If pause_entity is invalid
|
||||||
|
RuntimeError: If workflow is not paused
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This operation is irreversible. The stored workflow state will be
|
||||||
|
permanently deleted along with the pause record.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def prune_pauses(
|
||||||
|
self,
|
||||||
|
expiration: datetime,
|
||||||
|
resumption_expiration: datetime,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Sequence[str]:
|
||||||
|
"""
|
||||||
|
Clean up expired and old pause states.
|
||||||
|
|
||||||
|
Removes pause states that have expired (created before expiration time)
|
||||||
|
and pause states that were resumed more than resumption_duration ago.
|
||||||
|
This is used for maintenance and cleanup operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expiration: Remove pause states created before this time
|
||||||
|
resumption_expiration: Remove pause states resumed before this time
|
||||||
|
limit: maximum number of records deleted in one call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of ids for pause records that were pruned
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If parameters are invalid
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
def get_daily_runs_statistics(
|
def get_daily_runs_statistics(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|
|
||||||
|
|
@ -20,19 +20,26 @@ Implementation Notes:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import and_, delete, func, null, or_, select
|
||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||||
|
|
||||||
|
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||||
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from libs.time_parser import get_time_threshold
|
from libs.time_parser import get_time_threshold
|
||||||
|
from libs.uuid_utils import uuidv7
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
|
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||||
from models.workflow import WorkflowRun
|
from models.workflow import WorkflowRun
|
||||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
from repositories.types import (
|
from repositories.types import (
|
||||||
|
|
@ -45,6 +52,10 @@ from repositories.types import (
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _WorkflowRunError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||||
"""
|
"""
|
||||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||||
|
|
@ -301,6 +312,281 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||||
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
||||||
return total_deleted
|
return total_deleted
|
||||||
|
|
||||||
|
def create_workflow_pause(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
state_owner_user_id: str,
|
||||||
|
state: str,
|
||||||
|
) -> WorkflowPauseEntity:
|
||||||
|
"""
|
||||||
|
Create a new workflow pause state.
|
||||||
|
|
||||||
|
Creates a pause state for a workflow run, storing the current execution
|
||||||
|
state and marking the workflow as paused. This is used when a workflow
|
||||||
|
needs to be suspended and later resumed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: Identifier of the workflow run to pause
|
||||||
|
state_owner_user_id: User ID who owns the pause state for file storage
|
||||||
|
state: Serialized workflow execution state (JSON string)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RepositoryWorkflowPauseEntity representing the created pause state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
|
||||||
|
RuntimeError: If workflow is already paused or in invalid state
|
||||||
|
"""
|
||||||
|
previous_pause_model_query = select(WorkflowPauseModel).where(
|
||||||
|
WorkflowPauseModel.workflow_run_id == workflow_run_id
|
||||||
|
)
|
||||||
|
with self._session_maker() as session, session.begin():
|
||||||
|
# Get the workflow run
|
||||||
|
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||||
|
if workflow_run is None:
|
||||||
|
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||||
|
|
||||||
|
# Check if workflow is in RUNNING status
|
||||||
|
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||||
|
raise _WorkflowRunError(
|
||||||
|
f"Only WorkflowRun with RUNNING status can be paused, "
|
||||||
|
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
||||||
|
)
|
||||||
|
#
|
||||||
|
previous_pause = session.scalars(previous_pause_model_query).first()
|
||||||
|
if previous_pause:
|
||||||
|
self._delete_pause_model(session, previous_pause)
|
||||||
|
# we need to flush here to ensure that the old one is actually deleted.
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
state_obj_key = f"workflow-state-{uuid.uuid4()}.json"
|
||||||
|
storage.save(state_obj_key, state.encode())
|
||||||
|
# Upload the state file
|
||||||
|
|
||||||
|
# Create the pause record
|
||||||
|
pause_model = WorkflowPauseModel()
|
||||||
|
pause_model.id = str(uuidv7())
|
||||||
|
pause_model.workflow_id = workflow_run.workflow_id
|
||||||
|
pause_model.workflow_run_id = workflow_run.id
|
||||||
|
pause_model.state_object_key = state_obj_key
|
||||||
|
pause_model.created_at = naive_utc_now()
|
||||||
|
|
||||||
|
# Update workflow run status
|
||||||
|
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||||
|
|
||||||
|
# Save everything in a transaction
|
||||||
|
session.add(pause_model)
|
||||||
|
session.add(workflow_run)
|
||||||
|
|
||||||
|
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||||
|
|
||||||
|
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||||
|
|
||||||
|
def get_workflow_pause(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
) -> WorkflowPauseEntity | None:
|
||||||
|
"""
|
||||||
|
Get an existing workflow pause state.
|
||||||
|
|
||||||
|
Retrieves the pause state for a specific workflow run if it exists.
|
||||||
|
Used to check if a workflow is paused and to retrieve its saved state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: Identifier of the workflow run to get pause state for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RepositoryWorkflowPauseEntity if pause state exists, None otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workflow_run_id is invalid
|
||||||
|
"""
|
||||||
|
with self._session_maker() as session:
|
||||||
|
# Query workflow run with pause and state file
|
||||||
|
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id)
|
||||||
|
workflow_run = session.scalar(stmt)
|
||||||
|
|
||||||
|
if workflow_run is None:
|
||||||
|
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||||
|
|
||||||
|
pause_model = workflow_run.pause
|
||||||
|
if pause_model is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||||
|
|
||||||
|
def resume_workflow_pause(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
pause_entity: WorkflowPauseEntity,
|
||||||
|
) -> WorkflowPauseEntity:
|
||||||
|
"""
|
||||||
|
Resume a paused workflow.
|
||||||
|
|
||||||
|
Marks a paused workflow as resumed, clearing the pause state and
|
||||||
|
returning the workflow to running status. Returns the pause entity
|
||||||
|
that was resumed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: Identifier of the workflow run to resume
|
||||||
|
pause_entity: The pause entity to resume
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RepositoryWorkflowPauseEntity representing the resumed pause state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workflow_run_id is invalid
|
||||||
|
RuntimeError: If workflow is not paused or already resumed
|
||||||
|
"""
|
||||||
|
with self._session_maker() as session, session.begin():
|
||||||
|
# Get the workflow run with pause
|
||||||
|
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run_id)
|
||||||
|
workflow_run = session.scalar(stmt)
|
||||||
|
|
||||||
|
if workflow_run is None:
|
||||||
|
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||||
|
|
||||||
|
if workflow_run.status != WorkflowExecutionStatus.PAUSED:
|
||||||
|
raise _WorkflowRunError(
|
||||||
|
f"WorkflowRun is not in PAUSED status, workflow_run_id={workflow_run_id}, "
|
||||||
|
f"current_status={workflow_run.status}"
|
||||||
|
)
|
||||||
|
pause_model = workflow_run.pause
|
||||||
|
if pause_model is None:
|
||||||
|
raise _WorkflowRunError(f"No pause state found for workflow run: {workflow_run_id}")
|
||||||
|
|
||||||
|
if pause_model.id != pause_entity.id:
|
||||||
|
raise _WorkflowRunError(
|
||||||
|
"different id in WorkflowPause and WorkflowPauseEntity, "
|
||||||
|
f"WorkflowPause.id={pause_model.id}, "
|
||||||
|
f"WorkflowPauseEntity.id={pause_entity.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if pause_model.resumed_at is not None:
|
||||||
|
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
||||||
|
|
||||||
|
# Mark as resumed
|
||||||
|
pause_model.resumed_at = naive_utc_now()
|
||||||
|
workflow_run.pause_id = None # type: ignore
|
||||||
|
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
|
||||||
|
session.add(pause_model)
|
||||||
|
session.add(workflow_run)
|
||||||
|
|
||||||
|
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||||
|
|
||||||
|
return _PrivateWorkflowPauseEntity.from_models(pause_model)
|
||||||
|
|
||||||
|
def delete_workflow_pause(
|
||||||
|
self,
|
||||||
|
pause_entity: WorkflowPauseEntity,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Delete a workflow pause state.
|
||||||
|
|
||||||
|
Permanently removes the pause state for a workflow run, including
|
||||||
|
the stored state file. Used for cleanup operations when a paused
|
||||||
|
workflow is no longer needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pause_entity: The pause entity to delete
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If pause_entity is invalid
|
||||||
|
_WorkflowRunError: If workflow is not paused
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This operation is irreversible. The stored workflow state will be
|
||||||
|
permanently deleted along with the pause record.
|
||||||
|
"""
|
||||||
|
with self._session_maker() as session, session.begin():
|
||||||
|
# Get the pause model by ID
|
||||||
|
pause_model = session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
if pause_model is None:
|
||||||
|
raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
|
||||||
|
self._delete_pause_model(session, pause_model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
|
||||||
|
storage.delete(pause_model.state_object_key)
|
||||||
|
|
||||||
|
# Delete the pause record
|
||||||
|
session.delete(pause_model)
|
||||||
|
|
||||||
|
logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id)
|
||||||
|
|
||||||
|
def prune_pauses(
|
||||||
|
self,
|
||||||
|
expiration: datetime,
|
||||||
|
resumption_expiration: datetime,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> Sequence[str]:
|
||||||
|
"""
|
||||||
|
Clean up expired and old pause states.
|
||||||
|
|
||||||
|
Removes pause states that have expired (created before expiration time)
|
||||||
|
and pause states that were resumed more than resumption_duration ago.
|
||||||
|
This is used for maintenance and cleanup operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expiration: Remove pause states created before this time
|
||||||
|
resumption_expiration: Remove pause states resumed before this time
|
||||||
|
limit: maximum number of records deleted in one call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of ids for pause records that were pruned
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If parameters are invalid
|
||||||
|
"""
|
||||||
|
_limit: int = limit or 1000
|
||||||
|
pruned_record_ids: list[str] = []
|
||||||
|
cond = or_(
|
||||||
|
WorkflowPauseModel.created_at < expiration,
|
||||||
|
and_(
|
||||||
|
WorkflowPauseModel.resumed_at.is_not(null()),
|
||||||
|
WorkflowPauseModel.resumed_at < resumption_expiration,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# First, collect pause records to delete with their state files
|
||||||
|
# Expired pauses (created before expiration time)
|
||||||
|
stmt = select(WorkflowPauseModel).where(cond).limit(_limit)
|
||||||
|
|
||||||
|
with self._session_maker(expire_on_commit=False) as session:
|
||||||
|
# Old resumed pauses (resumed more than resumption_duration ago)
|
||||||
|
|
||||||
|
# Get all records to delete
|
||||||
|
pauses_to_delete = session.scalars(stmt).all()
|
||||||
|
|
||||||
|
# Delete state files from storage
|
||||||
|
for pause in pauses_to_delete:
|
||||||
|
with self._session_maker(expire_on_commit=False) as session, session.begin():
|
||||||
|
# todo: this issues a separate query for each WorkflowPauseModel record.
|
||||||
|
# consider batching this lookup.
|
||||||
|
try:
|
||||||
|
storage.delete(pause.state_object_key)
|
||||||
|
logger.info(
|
||||||
|
"Deleted state object for pause, pause_id=%s, object_key=%s",
|
||||||
|
pause.id,
|
||||||
|
pause.state_object_key,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to delete state file for pause, pause_id=%s, object_key=%s",
|
||||||
|
pause.id,
|
||||||
|
pause.state_object_key,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
session.delete(pause)
|
||||||
|
pruned_record_ids.append(pause.id)
|
||||||
|
logger.info(
|
||||||
|
"workflow pause records deleted, id=%s, resumed_at=%s",
|
||||||
|
pause.id,
|
||||||
|
pause.resumed_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pruned_record_ids
|
||||||
|
|
||||||
def get_daily_runs_statistics(
|
def get_daily_runs_statistics(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|
@ -510,3 +796,69 @@ GROUP BY
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(list[AverageInteractionStats], response_data)
|
return cast(list[AverageInteractionStats], response_data)
|
||||||
|
|
||||||
|
|
||||||
|
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||||
|
"""
|
||||||
|
Private implementation of WorkflowPauseEntity for SQLAlchemy repository.
|
||||||
|
|
||||||
|
This implementation is internal to the repository layer and provides
|
||||||
|
the concrete implementation of the WorkflowPauseEntity interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pause_model: WorkflowPauseModel,
|
||||||
|
) -> None:
|
||||||
|
self._pause_model = pause_model
|
||||||
|
self._cached_state: bytes | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
|
||||||
|
"""
|
||||||
|
Create a _PrivateWorkflowPauseEntity from database models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_pause_model: The WorkflowPause database model
|
||||||
|
upload_file_model: The UploadFile database model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_PrivateWorkflowPauseEntity: The constructed entity
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required model attributes are missing
|
||||||
|
"""
|
||||||
|
return cls(pause_model=workflow_pause_model)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> str:
|
||||||
|
return self._pause_model.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workflow_execution_id(self) -> str:
|
||||||
|
return self._pause_model.workflow_run_id
|
||||||
|
|
||||||
|
def get_state(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Retrieve the serialized workflow state from storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Mapping[str, Any]: The workflow state as a dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the state file cannot be found
|
||||||
|
IOError: If there are issues reading the state file
|
||||||
|
_Workflow: If the state cannot be deserialized properly
|
||||||
|
"""
|
||||||
|
if self._cached_state is not None:
|
||||||
|
return self._cached_state
|
||||||
|
|
||||||
|
# Load the state from storage
|
||||||
|
state_data = storage.load(self._pause_model.state_object_key)
|
||||||
|
self._cached_state = state_data
|
||||||
|
return state_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resumed_at(self) -> datetime | None:
|
||||||
|
return self._pause_model.resumed_at
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import threading
|
import threading
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from sqlalchemy import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
|
@ -14,17 +15,26 @@ from models import (
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunTriggeredFrom,
|
WorkflowRunTriggeredFrom,
|
||||||
)
|
)
|
||||||
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunService:
|
class WorkflowRunService:
|
||||||
def __init__(self):
|
_session_factory: sessionmaker
|
||||||
|
_workflow_run_repo: APIWorkflowRunRepository
|
||||||
|
|
||||||
|
def __init__(self, session_factory: Engine | sessionmaker | None = None):
|
||||||
"""Initialize WorkflowRunService with repository dependencies."""
|
"""Initialize WorkflowRunService with repository dependencies."""
|
||||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
if session_factory is None:
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
elif isinstance(session_factory, Engine):
|
||||||
|
session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||||
|
|
||||||
|
self._session_factory = session_factory
|
||||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||||
session_maker
|
self._session_factory
|
||||||
)
|
)
|
||||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
|
||||||
|
|
||||||
def get_paginate_advanced_chat_workflow_runs(
|
def get_paginate_advanced_chat_workflow_runs(
|
||||||
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
|
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
# Core integration tests package
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
# App integration tests package
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
# Layers integration tests package
|
||||||
|
|
@ -0,0 +1,520 @@
|
||||||
|
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
|
||||||
|
|
||||||
|
This test suite covers complete integration scenarios including:
|
||||||
|
- Real database interactions using containerized PostgreSQL
|
||||||
|
- Real storage operations using test storage backend
|
||||||
|
- Complete workflow: event -> state serialization -> database save -> storage save
|
||||||
|
- Testing with actual WorkflowRunService (not mocked)
|
||||||
|
- Real Workflow and WorkflowRun instances in database
|
||||||
|
- Database transactions and rollback behavior
|
||||||
|
- Actual file upload and retrieval through storage
|
||||||
|
- Workflow status transitions in database
|
||||||
|
- Error handling with real database constraints
|
||||||
|
- Multiple pause events in sequence
|
||||||
|
- Integration with real ReadOnlyGraphRuntimeState implementations
|
||||||
|
|
||||||
|
These tests use TestContainers to spin up real services for integration testing,
|
||||||
|
providing more reliable and realistic test scenarios than mocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Engine, delete, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
|
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||||
|
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||||
|
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||||
|
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||||
|
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
|
||||||
|
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models import Account
|
||||||
|
from models import WorkflowPause as WorkflowPauseModel
|
||||||
|
from models.model import UploadFile
|
||||||
|
from models.workflow import Workflow, WorkflowRun
|
||||||
|
from services.file_service import FileService
|
||||||
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
|
||||||
|
class _TestCommandChannelImpl:
|
||||||
|
"""Real implementation of CommandChannel for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._commands: list[GraphEngineCommand] = []
|
||||||
|
|
||||||
|
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||||
|
"""Fetch pending commands for this GraphEngine instance."""
|
||||||
|
return self._commands.copy()
|
||||||
|
|
||||||
|
def send_command(self, command: GraphEngineCommand) -> None:
|
||||||
|
"""Send a command to be processed by this GraphEngine instance."""
|
||||||
|
self._commands.append(command)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPauseStatePersistenceLayerTestContainers:
|
||||||
|
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def engine(self, db_session_with_containers: Session):
|
||||||
|
"""Get database engine from TestContainers session."""
|
||||||
|
bind = db_session_with_containers.get_bind()
|
||||||
|
assert isinstance(bind, Engine)
|
||||||
|
return bind
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def file_service(self, engine: Engine):
|
||||||
|
"""Create FileService instance with TestContainers engine."""
|
||||||
|
return FileService(engine)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workflow_run_service(self, engine: Engine, file_service: FileService):
|
||||||
|
"""Create WorkflowRunService instance with TestContainers engine and FileService."""
|
||||||
|
return WorkflowRunService(engine)
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
|
||||||
|
"""Set up test data for each test method using TestContainers."""
|
||||||
|
# Create test tenant and account
|
||||||
|
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
|
|
||||||
|
tenant = Tenant(
|
||||||
|
name="Test Tenant",
|
||||||
|
status="normal",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(tenant)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
account = Account(
|
||||||
|
email="test@example.com",
|
||||||
|
name="Test User",
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(account)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Create tenant-account join
|
||||||
|
tenant_join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
account_id=account.id,
|
||||||
|
role=TenantAccountRole.OWNER,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(tenant_join)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Set test data
|
||||||
|
self.test_tenant_id = tenant.id
|
||||||
|
self.test_user_id = account.id
|
||||||
|
self.test_app_id = str(uuid.uuid4())
|
||||||
|
self.test_workflow_id = str(uuid.uuid4())
|
||||||
|
self.test_workflow_run_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create test workflow
|
||||||
|
self.test_workflow = Workflow(
|
||||||
|
id=self.test_workflow_id,
|
||||||
|
tenant_id=self.test_tenant_id,
|
||||||
|
app_id=self.test_app_id,
|
||||||
|
type="workflow",
|
||||||
|
version="draft",
|
||||||
|
graph='{"nodes": [], "edges": []}',
|
||||||
|
features='{"file_upload": {"enabled": false}}',
|
||||||
|
created_by=self.test_user_id,
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create test workflow run
|
||||||
|
self.test_workflow_run = WorkflowRun(
|
||||||
|
id=self.test_workflow_run_id,
|
||||||
|
tenant_id=self.test_tenant_id,
|
||||||
|
app_id=self.test_app_id,
|
||||||
|
workflow_id=self.test_workflow_id,
|
||||||
|
type="workflow",
|
||||||
|
triggered_from="debugging",
|
||||||
|
version="draft",
|
||||||
|
status=WorkflowExecutionStatus.RUNNING,
|
||||||
|
created_by=self.test_user_id,
|
||||||
|
created_by_role="account",
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store session and service instances
|
||||||
|
self.session = db_session_with_containers
|
||||||
|
self.file_service = file_service
|
||||||
|
self.workflow_run_service = workflow_run_service
|
||||||
|
|
||||||
|
# Save test data to database
|
||||||
|
self.session.add(self.test_workflow)
|
||||||
|
self.session.add(self.test_workflow_run)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
self._cleanup_test_data()
|
||||||
|
|
||||||
|
def _cleanup_test_data(self):
|
||||||
|
"""Clean up test data after each test method."""
|
||||||
|
try:
|
||||||
|
# Clean up workflow pauses
|
||||||
|
self.session.execute(delete(WorkflowPauseModel))
|
||||||
|
# Clean up upload files
|
||||||
|
self.session.execute(
|
||||||
|
delete(UploadFile).where(
|
||||||
|
UploadFile.tenant_id == self.test_tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Clean up workflow runs
|
||||||
|
self.session.execute(
|
||||||
|
delete(WorkflowRun).where(
|
||||||
|
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||||
|
WorkflowRun.app_id == self.test_app_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Clean up workflows
|
||||||
|
self.session.execute(
|
||||||
|
delete(Workflow).where(
|
||||||
|
Workflow.tenant_id == self.test_tenant_id,
|
||||||
|
Workflow.app_id == self.test_app_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
self.session.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _create_graph_runtime_state(
|
||||||
|
self,
|
||||||
|
outputs: dict[str, object] | None = None,
|
||||||
|
total_tokens: int = 0,
|
||||||
|
node_run_steps: int = 0,
|
||||||
|
variables: dict[tuple[str, str], object] | None = None,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
) -> ReadOnlyGraphRuntimeState:
|
||||||
|
"""Create a real GraphRuntimeState for testing."""
|
||||||
|
start_at = time()
|
||||||
|
|
||||||
|
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create variable pool
|
||||||
|
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
|
||||||
|
if variables:
|
||||||
|
for (node_id, var_key), value in variables.items():
|
||||||
|
variable_pool.add([node_id, var_key], value)
|
||||||
|
|
||||||
|
# Create LLM usage
|
||||||
|
llm_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
# Create graph runtime state
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=start_at,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
llm_usage=llm_usage,
|
||||||
|
outputs=outputs or {},
|
||||||
|
node_run_steps=node_run_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
|
||||||
|
|
||||||
|
def _create_pause_state_persistence_layer(
|
||||||
|
self,
|
||||||
|
workflow_run: WorkflowRun | None = None,
|
||||||
|
workflow: Workflow | None = None,
|
||||||
|
state_owner_user_id: str | None = None,
|
||||||
|
) -> PauseStatePersistenceLayer:
|
||||||
|
"""Create PauseStatePersistenceLayer with real dependencies."""
|
||||||
|
owner_id = state_owner_user_id
|
||||||
|
if owner_id is None:
|
||||||
|
if workflow is not None and workflow.created_by:
|
||||||
|
owner_id = workflow.created_by
|
||||||
|
elif workflow_run is not None and workflow_run.created_by:
|
||||||
|
owner_id = workflow_run.created_by
|
||||||
|
else:
|
||||||
|
owner_id = getattr(self, "test_user_id", None)
|
||||||
|
|
||||||
|
assert owner_id is not None
|
||||||
|
owner_id = str(owner_id)
|
||||||
|
|
||||||
|
return PauseStatePersistenceLayer(
|
||||||
|
session_factory=self.session.get_bind(),
|
||||||
|
state_owner_user_id=owner_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
|
||||||
|
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
|
||||||
|
# Arrange
|
||||||
|
layer = self._create_pause_state_persistence_layer()
|
||||||
|
|
||||||
|
# Create real graph runtime state with test data
|
||||||
|
test_outputs = {"result": "test_output", "step": "intermediate"}
|
||||||
|
test_variables = {
|
||||||
|
("node1", "var1"): "string_value",
|
||||||
|
("node2", "var2"): {"complex": "object"},
|
||||||
|
}
|
||||||
|
graph_runtime_state = self._create_graph_runtime_state(
|
||||||
|
outputs=test_outputs,
|
||||||
|
total_tokens=100,
|
||||||
|
node_run_steps=5,
|
||||||
|
variables=test_variables,
|
||||||
|
)
|
||||||
|
|
||||||
|
command_channel = _TestCommandChannelImpl()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
# Create pause event
|
||||||
|
event = GraphRunPausedEvent(
|
||||||
|
reason=SchedulingPause(message="test pause"),
|
||||||
|
outputs={"intermediate": "result"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
# Assert - Verify pause state was saved to database
|
||||||
|
self.session.refresh(self.test_workflow_run)
|
||||||
|
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
|
||||||
|
assert workflow_run is not None
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||||
|
|
||||||
|
# Verify pause state exists in database
|
||||||
|
pause_model = self.session.scalars(
|
||||||
|
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||||
|
).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.workflow_id == self.test_workflow_id
|
||||||
|
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||||
|
assert pause_model.state_object_key != ""
|
||||||
|
assert pause_model.resumed_at is None
|
||||||
|
|
||||||
|
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||||
|
expected_state = json.loads(graph_runtime_state.dumps())
|
||||||
|
actual_state = json.loads(storage_content)
|
||||||
|
|
||||||
|
assert actual_state == expected_state
|
||||||
|
|
||||||
|
def test_state_persistence_and_retrieval(self, db_session_with_containers):
|
||||||
|
"""Test that pause state can be persisted and retrieved correctly."""
|
||||||
|
# Arrange
|
||||||
|
layer = self._create_pause_state_persistence_layer()
|
||||||
|
|
||||||
|
# Create complex test data
|
||||||
|
complex_outputs = {
|
||||||
|
"nested": {"key": "value", "number": 42},
|
||||||
|
"list": [1, 2, 3, {"nested": "item"}],
|
||||||
|
"boolean": True,
|
||||||
|
"null_value": None,
|
||||||
|
}
|
||||||
|
complex_variables = {
|
||||||
|
("node1", "var1"): "string_value",
|
||||||
|
("node2", "var2"): {"complex": "object"},
|
||||||
|
("node3", "var3"): [1, 2, 3],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_runtime_state = self._create_graph_runtime_state(
|
||||||
|
outputs=complex_outputs,
|
||||||
|
total_tokens=250,
|
||||||
|
node_run_steps=10,
|
||||||
|
variables=complex_variables,
|
||||||
|
)
|
||||||
|
|
||||||
|
command_channel = _TestCommandChannelImpl()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||||
|
|
||||||
|
# Act - Save pause state
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
# Assert - Retrieve and verify
|
||||||
|
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
|
||||||
|
assert pause_entity is not None
|
||||||
|
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||||
|
|
||||||
|
state_bytes = pause_entity.get_state()
|
||||||
|
retrieved_state = json.loads(state_bytes.decode())
|
||||||
|
expected_state = json.loads(graph_runtime_state.dumps())
|
||||||
|
|
||||||
|
assert retrieved_state == expected_state
|
||||||
|
assert retrieved_state["outputs"] == complex_outputs
|
||||||
|
assert retrieved_state["total_tokens"] == 250
|
||||||
|
assert retrieved_state["node_run_steps"] == 10
|
||||||
|
|
||||||
|
def test_database_transaction_handling(self, db_session_with_containers):
|
||||||
|
"""Test that database transactions are handled correctly."""
|
||||||
|
# Arrange
|
||||||
|
layer = self._create_pause_state_persistence_layer()
|
||||||
|
graph_runtime_state = self._create_graph_runtime_state(
|
||||||
|
outputs={"test": "transaction"},
|
||||||
|
total_tokens=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
command_channel = _TestCommandChannelImpl()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
# Assert - Verify data is committed and accessible in new session
|
||||||
|
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
|
||||||
|
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
|
||||||
|
assert workflow_run is not None
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||||
|
|
||||||
|
pause_model = new_session.scalars(
|
||||||
|
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||||
|
).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||||
|
assert pause_model.resumed_at is None
|
||||||
|
assert pause_model.state_object_key != ""
|
||||||
|
|
||||||
|
def test_file_storage_integration(self, db_session_with_containers):
|
||||||
|
"""Test integration with file storage system."""
|
||||||
|
# Arrange
|
||||||
|
layer = self._create_pause_state_persistence_layer()
|
||||||
|
|
||||||
|
# Create large state data to test storage
|
||||||
|
large_outputs = {"data": "x" * 10000} # 10KB of data
|
||||||
|
graph_runtime_state = self._create_graph_runtime_state(
|
||||||
|
outputs=large_outputs,
|
||||||
|
total_tokens=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
command_channel = _TestCommandChannelImpl()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
# Assert - Verify file was uploaded to storage
|
||||||
|
self.session.refresh(self.test_workflow_run)
|
||||||
|
pause_model = self.session.scalars(
|
||||||
|
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
|
||||||
|
).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.state_object_key != ""
|
||||||
|
|
||||||
|
# Verify content in storage
|
||||||
|
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||||
|
assert storage_content == graph_runtime_state.dumps()
|
||||||
|
|
||||||
|
def test_workflow_with_different_creators(self, db_session_with_containers):
|
||||||
|
"""Test pause state with workflows created by different users."""
|
||||||
|
# Arrange - Create workflow with different creator
|
||||||
|
different_user_id = str(uuid.uuid4())
|
||||||
|
different_workflow = Workflow(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id=self.test_tenant_id,
|
||||||
|
app_id=self.test_app_id,
|
||||||
|
type="workflow",
|
||||||
|
version="draft",
|
||||||
|
graph='{"nodes": [], "edges": []}',
|
||||||
|
features='{"file_upload": {"enabled": false}}',
|
||||||
|
created_by=different_user_id,
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
different_workflow_run = WorkflowRun(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id=self.test_tenant_id,
|
||||||
|
app_id=self.test_app_id,
|
||||||
|
workflow_id=different_workflow.id,
|
||||||
|
type="workflow",
|
||||||
|
triggered_from="debugging",
|
||||||
|
version="draft",
|
||||||
|
status=WorkflowExecutionStatus.RUNNING,
|
||||||
|
created_by=self.test_user_id, # Run created by different user
|
||||||
|
created_by_role="account",
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.add(different_workflow)
|
||||||
|
self.session.add(different_workflow_run)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
layer = self._create_pause_state_persistence_layer(
|
||||||
|
workflow_run=different_workflow_run,
|
||||||
|
workflow=different_workflow,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_runtime_state = self._create_graph_runtime_state(
|
||||||
|
outputs={"creator_test": "different_creator"},
|
||||||
|
workflow_run_id=different_workflow_run.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
command_channel = _TestCommandChannelImpl()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
# Assert - Should use workflow creator (not run creator)
|
||||||
|
self.session.refresh(different_workflow_run)
|
||||||
|
pause_model = self.session.scalars(
|
||||||
|
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
|
||||||
|
).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
|
||||||
|
# Verify the state owner is the workflow creator
|
||||||
|
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
|
||||||
|
assert pause_entity is not None
|
||||||
|
|
||||||
|
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
|
||||||
|
"""Test that layer ignores non-pause events."""
|
||||||
|
# Arrange
|
||||||
|
layer = self._create_pause_state_persistence_layer()
|
||||||
|
graph_runtime_state = self._create_graph_runtime_state()
|
||||||
|
|
||||||
|
command_channel = _TestCommandChannelImpl()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
# Import other event types
|
||||||
|
from core.workflow.graph_events.graph import (
|
||||||
|
GraphRunFailedEvent,
|
||||||
|
GraphRunStartedEvent,
|
||||||
|
GraphRunSucceededEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act - Send non-pause events
|
||||||
|
layer.on_event(GraphRunStartedEvent())
|
||||||
|
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
|
||||||
|
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||||
|
|
||||||
|
# Assert - No pause state should be created
|
||||||
|
self.session.refresh(self.test_workflow_run)
|
||||||
|
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||||
|
|
||||||
|
pause_states = (
|
||||||
|
self.session.query(WorkflowPauseModel)
|
||||||
|
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(pause_states) == 0
|
||||||
|
|
||||||
|
def test_layer_requires_initialization(self, db_session_with_containers):
|
||||||
|
"""Test that layer requires proper initialization before handling events."""
|
||||||
|
# Arrange
|
||||||
|
layer = self._create_pause_state_persistence_layer()
|
||||||
|
# Don't initialize - graph_runtime_state should not be set
|
||||||
|
|
||||||
|
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||||
|
|
||||||
|
# Act & Assert - Should raise AttributeError
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
layer.on_event(event)
|
||||||
|
|
@ -0,0 +1,948 @@
|
||||||
|
"""Comprehensive integration tests for workflow pause functionality.
|
||||||
|
|
||||||
|
This test suite covers complete workflow pause functionality including:
|
||||||
|
- Real database interactions using containerized PostgreSQL
|
||||||
|
- Real storage operations using the test storage backend
|
||||||
|
- Complete workflow: create -> pause -> resume -> delete
|
||||||
|
- Testing with actual FileService (not mocked)
|
||||||
|
- Database transactions and rollback behavior
|
||||||
|
- Actual file upload and retrieval through storage
|
||||||
|
- Workflow status transitions in the database
|
||||||
|
- Error handling with real database constraints
|
||||||
|
- Concurrent access scenarios
|
||||||
|
- Multi-tenant isolation
|
||||||
|
- Prune functionality
|
||||||
|
- File storage integration
|
||||||
|
|
||||||
|
These tests use TestContainers to spin up real services for integration testing,
|
||||||
|
providing more reliable and realistic test scenarios than mocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||||
|
|
||||||
|
from core.workflow.entities import WorkflowExecution
|
||||||
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models import Account
|
||||||
|
from models import WorkflowPause as WorkflowPauseModel
|
||||||
|
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
|
from models.model import UploadFile
|
||||||
|
from models.workflow import Workflow, WorkflowRun
|
||||||
|
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||||
|
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
_WorkflowRunError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PauseWorkflowSuccessCase:
|
||||||
|
"""Test case for successful pause workflow operations."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
initial_status: WorkflowExecutionStatus
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PauseWorkflowFailureCase:
|
||||||
|
"""Test case for pause workflow failure scenarios."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
initial_status: WorkflowExecutionStatus
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResumeWorkflowSuccessCase:
|
||||||
|
"""Test case for successful resume workflow operations."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
initial_status: WorkflowExecutionStatus
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResumeWorkflowFailureCase:
|
||||||
|
"""Test case for resume workflow failure scenarios."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
initial_status: WorkflowExecutionStatus
|
||||||
|
pause_resumed: bool
|
||||||
|
set_running_status: bool = False
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PrunePausesTestCase:
|
||||||
|
"""Test case for prune pauses operations."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
pause_age: timedelta
|
||||||
|
resume_age: timedelta | None
|
||||||
|
expected_pruned_count: int
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
|
||||||
|
"""Create test cases for pause workflow failure scenarios."""
|
||||||
|
return [
|
||||||
|
PauseWorkflowFailureCase(
|
||||||
|
name="pause_already_paused_workflow",
|
||||||
|
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||||
|
description="Should fail to pause an already paused workflow",
|
||||||
|
),
|
||||||
|
PauseWorkflowFailureCase(
|
||||||
|
name="pause_completed_workflow",
|
||||||
|
initial_status=WorkflowExecutionStatus.SUCCEEDED,
|
||||||
|
description="Should fail to pause a completed workflow",
|
||||||
|
),
|
||||||
|
PauseWorkflowFailureCase(
|
||||||
|
name="pause_failed_workflow",
|
||||||
|
initial_status=WorkflowExecutionStatus.FAILED,
|
||||||
|
description="Should fail to pause a failed workflow",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
|
||||||
|
"""Create test cases for successful resume workflow operations."""
|
||||||
|
return [
|
||||||
|
ResumeWorkflowSuccessCase(
|
||||||
|
name="resume_paused_workflow",
|
||||||
|
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||||
|
description="Should successfully resume a paused workflow",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
|
||||||
|
"""Create test cases for resume workflow failure scenarios."""
|
||||||
|
return [
|
||||||
|
ResumeWorkflowFailureCase(
|
||||||
|
name="resume_already_resumed_workflow",
|
||||||
|
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||||
|
pause_resumed=True,
|
||||||
|
description="Should fail to resume an already resumed workflow",
|
||||||
|
),
|
||||||
|
ResumeWorkflowFailureCase(
|
||||||
|
name="resume_running_workflow",
|
||||||
|
initial_status=WorkflowExecutionStatus.RUNNING,
|
||||||
|
pause_resumed=False,
|
||||||
|
set_running_status=True,
|
||||||
|
description="Should fail to resume a running workflow",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
|
||||||
|
"""Create test cases for prune pauses operations."""
|
||||||
|
return [
|
||||||
|
PrunePausesTestCase(
|
||||||
|
name="prune_old_active_pauses",
|
||||||
|
pause_age=timedelta(days=7),
|
||||||
|
resume_age=None,
|
||||||
|
expected_pruned_count=1,
|
||||||
|
description="Should prune old active pauses",
|
||||||
|
),
|
||||||
|
PrunePausesTestCase(
|
||||||
|
name="prune_old_resumed_pauses",
|
||||||
|
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
|
||||||
|
resume_age=timedelta(days=7),
|
||||||
|
expected_pruned_count=1,
|
||||||
|
description="Should prune old resumed pauses",
|
||||||
|
),
|
||||||
|
PrunePausesTestCase(
|
||||||
|
name="keep_recent_active_pauses",
|
||||||
|
pause_age=timedelta(hours=1),
|
||||||
|
resume_age=None,
|
||||||
|
expected_pruned_count=0,
|
||||||
|
description="Should keep recent active pauses",
|
||||||
|
),
|
||||||
|
PrunePausesTestCase(
|
||||||
|
name="keep_recent_resumed_pauses",
|
||||||
|
pause_age=timedelta(days=1),
|
||||||
|
resume_age=timedelta(hours=1),
|
||||||
|
expected_pruned_count=0,
|
||||||
|
description="Should keep recent resumed pauses",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowPauseIntegration:
|
||||||
|
"""Comprehensive integration tests for workflow pause functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_test_data(self, db_session_with_containers):
|
||||||
|
"""Set up test data for each test method using TestContainers."""
|
||||||
|
# Create test tenant and account
|
||||||
|
|
||||||
|
tenant = Tenant(
|
||||||
|
name="Test Tenant",
|
||||||
|
status="normal",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(tenant)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
account = Account(
|
||||||
|
email="test@example.com",
|
||||||
|
name="Test User",
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(account)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Create tenant-account join
|
||||||
|
tenant_join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
account_id=account.id,
|
||||||
|
role=TenantAccountRole.OWNER,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(tenant_join)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
# Set test data
|
||||||
|
self.test_tenant_id = tenant.id
|
||||||
|
self.test_user_id = account.id
|
||||||
|
self.test_app_id = str(uuid.uuid4())
|
||||||
|
self.test_workflow_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create test workflow
|
||||||
|
self.test_workflow = Workflow(
|
||||||
|
id=self.test_workflow_id,
|
||||||
|
tenant_id=self.test_tenant_id,
|
||||||
|
app_id=self.test_app_id,
|
||||||
|
type="workflow",
|
||||||
|
version="draft",
|
||||||
|
graph='{"nodes": [], "edges": []}',
|
||||||
|
features='{"file_upload": {"enabled": false}}',
|
||||||
|
created_by=self.test_user_id,
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store session instance
|
||||||
|
self.session = db_session_with_containers
|
||||||
|
|
||||||
|
# Save test data to database
|
||||||
|
self.session.add(self.test_workflow)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
self._cleanup_test_data()
|
||||||
|
|
||||||
|
def _cleanup_test_data(self):
|
||||||
|
"""Clean up test data after each test method."""
|
||||||
|
# Clean up workflow pauses
|
||||||
|
self.session.execute(delete(WorkflowPauseModel))
|
||||||
|
# Clean up upload files
|
||||||
|
self.session.execute(
|
||||||
|
delete(UploadFile).where(
|
||||||
|
UploadFile.tenant_id == self.test_tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Clean up workflow runs
|
||||||
|
self.session.execute(
|
||||||
|
delete(WorkflowRun).where(
|
||||||
|
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||||
|
WorkflowRun.app_id == self.test_app_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Clean up workflows
|
||||||
|
self.session.execute(
|
||||||
|
delete(Workflow).where(
|
||||||
|
Workflow.tenant_id == self.test_tenant_id,
|
||||||
|
Workflow.app_id == self.test_app_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def _create_test_workflow_run(
|
||||||
|
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
|
||||||
|
) -> WorkflowRun:
|
||||||
|
"""Create a test workflow run with specified status."""
|
||||||
|
workflow_run = WorkflowRun(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id=self.test_tenant_id,
|
||||||
|
app_id=self.test_app_id,
|
||||||
|
workflow_id=self.test_workflow_id,
|
||||||
|
type="workflow",
|
||||||
|
triggered_from="debugging",
|
||||||
|
version="draft",
|
||||||
|
status=status,
|
||||||
|
created_by=self.test_user_id,
|
||||||
|
created_by_role="account",
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
self.session.add(workflow_run)
|
||||||
|
self.session.commit()
|
||||||
|
return workflow_run
|
||||||
|
|
||||||
|
def _create_test_state(self) -> str:
|
||||||
|
"""Create a test state string."""
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"node_id": "test-node",
|
||||||
|
"node_type": "llm",
|
||||||
|
"status": "paused",
|
||||||
|
"data": {"key": "value"},
|
||||||
|
"timestamp": naive_utc_now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_workflow_run_repository(self):
|
||||||
|
"""Get workflow run repository instance for testing."""
|
||||||
|
# Create session factory from the test session
|
||||||
|
engine = self.session.get_bind()
|
||||||
|
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
# Create a test-specific repository that implements the missing save method
|
||||||
|
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||||
|
"""Test-specific repository that implements the missing save method."""
|
||||||
|
|
||||||
|
def save(self, execution: WorkflowExecution):
|
||||||
|
"""Implement the missing save method for testing."""
|
||||||
|
# For testing purposes, we don't need to implement this method
|
||||||
|
# as it's not used in the pause functionality tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Create and return repository instance
|
||||||
|
repository = TestWorkflowRunRepository(session_maker=session_factory)
|
||||||
|
return repository
|
||||||
|
|
||||||
|
# ==================== Complete Pause Workflow Tests ====================
|
||||||
|
|
||||||
|
def test_complete_pause_resume_workflow(self):
|
||||||
|
"""Test complete workflow: create -> pause -> resume -> delete."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
# Act - Create pause state
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Pause state created
|
||||||
|
assert pause_entity is not None
|
||||||
|
assert pause_entity.id is not None
|
||||||
|
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||||
|
# Convert both to strings for comparison
|
||||||
|
retrieved_state = pause_entity.get_state()
|
||||||
|
if isinstance(retrieved_state, bytes):
|
||||||
|
retrieved_state = retrieved_state.decode()
|
||||||
|
assert retrieved_state == test_state
|
||||||
|
|
||||||
|
# Verify database state
|
||||||
|
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||||
|
pause_model = self.session.scalars(query).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.resumed_at is None
|
||||||
|
assert pause_model.id == pause_entity.id
|
||||||
|
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||||
|
|
||||||
|
# Act - Get pause state
|
||||||
|
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
|
||||||
|
|
||||||
|
# Assert - Pause state retrieved
|
||||||
|
assert retrieved_entity is not None
|
||||||
|
assert retrieved_entity.id == pause_entity.id
|
||||||
|
retrieved_state = retrieved_entity.get_state()
|
||||||
|
if isinstance(retrieved_state, bytes):
|
||||||
|
retrieved_state = retrieved_state.decode()
|
||||||
|
assert retrieved_state == test_state
|
||||||
|
|
||||||
|
# Act - Resume workflow
|
||||||
|
resumed_entity = repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Workflow resumed
|
||||||
|
assert resumed_entity is not None
|
||||||
|
assert resumed_entity.id == pause_entity.id
|
||||||
|
assert resumed_entity.resumed_at is not None
|
||||||
|
|
||||||
|
# Verify database state
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||||
|
self.session.refresh(pause_model)
|
||||||
|
assert pause_model.resumed_at is not None
|
||||||
|
|
||||||
|
# Act - Delete pause state
|
||||||
|
repository.delete_workflow_pause(pause_entity)
|
||||||
|
|
||||||
|
# Assert - Pause state deleted
|
||||||
|
with Session(bind=self.session.get_bind()) as session:
|
||||||
|
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
assert deleted_pause is None
|
||||||
|
|
||||||
|
def test_pause_workflow_success(self):
|
||||||
|
"""Test successful pause workflow scenarios."""
|
||||||
|
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert pause_entity is not None
|
||||||
|
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||||
|
|
||||||
|
retrieved_state = pause_entity.get_state()
|
||||||
|
if isinstance(retrieved_state, bytes):
|
||||||
|
retrieved_state = retrieved_state.decode()
|
||||||
|
assert retrieved_state == test_state
|
||||||
|
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||||
|
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||||
|
pause_model = self.session.scalars(pause_query).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.id == pause_entity.id
|
||||||
|
assert pause_model.resumed_at is None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
|
||||||
|
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
|
||||||
|
"""Test pause workflow failure scenarios."""
|
||||||
|
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
with pytest.raises(_WorkflowRunError):
|
||||||
|
repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
|
||||||
|
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
|
||||||
|
"""Test successful resume workflow scenarios."""
|
||||||
|
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||||
|
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||||
|
|
||||||
|
resumed_entity = repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
assert resumed_entity is not None
|
||||||
|
assert resumed_entity.id == pause_entity.id
|
||||||
|
assert resumed_entity.resumed_at is not None
|
||||||
|
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||||
|
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||||
|
pause_model = self.session.scalars(pause_query).first()
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.id == pause_entity.id
|
||||||
|
assert pause_model.resumed_at is not None
|
||||||
|
|
||||||
|
def test_resume_running_workflow(self):
|
||||||
|
"""Test resume workflow failure scenarios."""
|
||||||
|
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
self.session.add(workflow_run)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(_WorkflowRunError):
|
||||||
|
repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_resume_resumed_pause(self):
|
||||||
|
"""Test resume workflow failure scenarios."""
|
||||||
|
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
pause_model.resumed_at = naive_utc_now()
|
||||||
|
self.session.add(pause_model)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(_WorkflowRunError):
|
||||||
|
repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== Error Scenario Tests ====================
|
||||||
|
|
||||||
|
def test_pause_nonexistent_workflow_run(self):
|
||||||
|
"""Test pausing a non-existent workflow run."""
|
||||||
|
# Arrange
|
||||||
|
nonexistent_id = str(uuid.uuid4())
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||||
|
repository.create_workflow_pause(
|
||||||
|
workflow_run_id=nonexistent_id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_resume_nonexistent_workflow_run(self):
|
||||||
|
"""Test resuming a non-existent workflow run."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
nonexistent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||||
|
repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=nonexistent_id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== Prune Functionality Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
|
||||||
|
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
|
||||||
|
"""Test various prune pauses scenarios."""
|
||||||
|
now = naive_utc_now()
|
||||||
|
|
||||||
|
# Create pause state
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually adjust timestamps for testing
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
pause_model.created_at = now - test_case.pause_age
|
||||||
|
|
||||||
|
if test_case.resume_age is not None:
|
||||||
|
# Resume pause and adjust resume time
|
||||||
|
repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
# Need to refresh to get the updated model
|
||||||
|
self.session.refresh(pause_model)
|
||||||
|
# Manually set the resumed_at to an older time for testing
|
||||||
|
pause_model.resumed_at = now - test_case.resume_age
|
||||||
|
self.session.commit() # Commit the resumed_at change
|
||||||
|
# Refresh again to ensure the change is persisted
|
||||||
|
self.session.refresh(pause_model)
|
||||||
|
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
# Act - Prune pauses
|
||||||
|
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
|
||||||
|
resumption_time = now - timedelta(
|
||||||
|
days=7, seconds=1
|
||||||
|
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
|
||||||
|
|
||||||
|
# Debug: Check pause state before pruning
|
||||||
|
self.session.refresh(pause_model)
|
||||||
|
print(f"Pause created_at: {pause_model.created_at}")
|
||||||
|
print(f"Pause resumed_at: {pause_model.resumed_at}")
|
||||||
|
print(f"Expiration time: {expiration_time}")
|
||||||
|
print(f"Resumption time: {resumption_time}")
|
||||||
|
|
||||||
|
# Force commit to ensure timestamps are saved
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
# Determine if the pause should be pruned based on timestamps
|
||||||
|
should_be_pruned = False
|
||||||
|
if test_case.resume_age is not None:
|
||||||
|
# If resumed, check if resumed_at is older than resumption_time
|
||||||
|
should_be_pruned = pause_model.resumed_at < resumption_time
|
||||||
|
else:
|
||||||
|
# If not resumed, check if created_at is older than expiration_time
|
||||||
|
should_be_pruned = pause_model.created_at < expiration_time
|
||||||
|
|
||||||
|
# Act - Prune pauses
|
||||||
|
pruned_ids = repository.prune_pauses(
|
||||||
|
expiration=expiration_time,
|
||||||
|
resumption_expiration=resumption_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Check pruning results
|
||||||
|
if should_be_pruned:
|
||||||
|
assert len(pruned_ids) == test_case.expected_pruned_count
|
||||||
|
# Verify pause was actually deleted
|
||||||
|
# The pause should be in the pruned_ids list if it was pruned
|
||||||
|
assert pause_entity.id in pruned_ids
|
||||||
|
else:
|
||||||
|
assert len(pruned_ids) == 0
|
||||||
|
|
||||||
|
def test_prune_pauses_with_limit(self):
|
||||||
|
"""Test prune pauses with limit parameter."""
|
||||||
|
now = naive_utc_now()
|
||||||
|
|
||||||
|
# Create multiple pause states
|
||||||
|
pause_entities = []
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
pause_entities.append(pause_entity)
|
||||||
|
|
||||||
|
# Make all pauses old enough to be pruned
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
pause_model.created_at = now - timedelta(days=7)
|
||||||
|
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
# Act - Prune with limit
|
||||||
|
expiration_time = now - timedelta(days=1)
|
||||||
|
resumption_time = now - timedelta(days=7)
|
||||||
|
|
||||||
|
pruned_ids = repository.prune_pauses(
|
||||||
|
expiration=expiration_time,
|
||||||
|
resumption_expiration=resumption_time,
|
||||||
|
limit=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(pruned_ids) == 3
|
||||||
|
|
||||||
|
# Verify only 3 were deleted
|
||||||
|
remaining_count = (
|
||||||
|
self.session.query(WorkflowPauseModel)
|
||||||
|
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
assert remaining_count == 2
|
||||||
|
|
||||||
|
# ==================== Multi-tenant Isolation Tests ====================
|
||||||
|
|
||||||
|
def test_multi_tenant_pause_isolation(self):
|
||||||
|
"""Test that pause states are properly isolated by tenant."""
|
||||||
|
# Arrange - Create second tenant
|
||||||
|
|
||||||
|
tenant2 = Tenant(
|
||||||
|
name="Test Tenant 2",
|
||||||
|
status="normal",
|
||||||
|
)
|
||||||
|
self.session.add(tenant2)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
account2 = Account(
|
||||||
|
email="test2@example.com",
|
||||||
|
name="Test User 2",
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
self.session.add(account2)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
tenant2_join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant2.id,
|
||||||
|
account_id=account2.id,
|
||||||
|
role=TenantAccountRole.OWNER,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
self.session.add(tenant2_join)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
# Create workflow for tenant 2
|
||||||
|
workflow2 = Workflow(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id=tenant2.id,
|
||||||
|
app_id=str(uuid.uuid4()),
|
||||||
|
type="workflow",
|
||||||
|
version="draft",
|
||||||
|
graph='{"nodes": [], "edges": []}',
|
||||||
|
features='{"file_upload": {"enabled": false}}',
|
||||||
|
created_by=account2.id,
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
self.session.add(workflow2)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
# Create workflow runs for both tenants
|
||||||
|
workflow_run1 = self._create_test_workflow_run()
|
||||||
|
workflow_run2 = WorkflowRun(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id=tenant2.id,
|
||||||
|
app_id=workflow2.app_id,
|
||||||
|
workflow_id=workflow2.id,
|
||||||
|
type="workflow",
|
||||||
|
triggered_from="debugging",
|
||||||
|
version="draft",
|
||||||
|
status=WorkflowExecutionStatus.RUNNING,
|
||||||
|
created_by=account2.id,
|
||||||
|
created_by_role="account",
|
||||||
|
created_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
self.session.add(workflow_run2)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
# Act - Create pause for tenant 1
|
||||||
|
pause_entity1 = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run1.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to access pause from tenant 2 using tenant 1's repository
|
||||||
|
# This should work because we're using the same repository
|
||||||
|
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
|
||||||
|
assert pause_entity2 is None # No pause for tenant 2 yet
|
||||||
|
|
||||||
|
# Create pause for tenant 2
|
||||||
|
pause_entity2 = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run2.id,
|
||||||
|
state_owner_user_id=account2.id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Both pauses should exist and be separate
|
||||||
|
assert pause_entity1 is not None
|
||||||
|
assert pause_entity2 is not None
|
||||||
|
assert pause_entity1.id != pause_entity2.id
|
||||||
|
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
|
||||||
|
|
||||||
|
def test_cross_tenant_access_restriction(self):
|
||||||
|
"""Test that cross-tenant access is properly restricted."""
|
||||||
|
# This test would require tenant-specific repositories
|
||||||
|
# For now, we test that pause entities are properly scoped by tenant_id
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify pause is properly scoped
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
assert pause_model.workflow_id == self.test_workflow_id
|
||||||
|
|
||||||
|
# ==================== File Storage Integration Tests ====================
|
||||||
|
|
||||||
|
def test_file_storage_integration(self):
|
||||||
|
"""Test that state files are properly stored and retrieved."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
# Act - Create pause state
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Verify file was uploaded to storage
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
assert pause_model.state_object_key != ""
|
||||||
|
|
||||||
|
# Verify file content in storage
|
||||||
|
|
||||||
|
file_key = pause_model.state_object_key
|
||||||
|
storage_content = storage.load(file_key).decode()
|
||||||
|
assert storage_content == test_state
|
||||||
|
|
||||||
|
# Verify retrieval through entity
|
||||||
|
retrieved_state = pause_entity.get_state()
|
||||||
|
if isinstance(retrieved_state, bytes):
|
||||||
|
retrieved_state = retrieved_state.decode()
|
||||||
|
assert retrieved_state == test_state
|
||||||
|
|
||||||
|
def test_file_cleanup_on_pause_deletion(self):
|
||||||
|
"""Test that files are properly handled on pause deletion."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
test_state = self._create_test_state()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=test_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get file info before deletion
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
file_key = pause_model.state_object_key
|
||||||
|
|
||||||
|
# Act - Delete pause state
|
||||||
|
repository.delete_workflow_pause(pause_entity)
|
||||||
|
|
||||||
|
# Assert - Pause record should be deleted
|
||||||
|
self.session.expire_all() # Clear session to ensure fresh query
|
||||||
|
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
assert deleted_pause is None
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = storage.load(file_key).decode()
|
||||||
|
pytest.fail("File should be deleted from storage after pause deletion")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# This is expected - file should be deleted from storage
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Unexpected error when checking file deletion: {e}")
|
||||||
|
|
||||||
|
def test_large_state_file_handling(self):
|
||||||
|
"""Test handling of large state files."""
|
||||||
|
# Arrange - Create a large state (1MB)
|
||||||
|
large_state = "x" * (1024 * 1024) # 1MB of data
|
||||||
|
large_state_json = json.dumps({"large_data": large_state})
|
||||||
|
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=large_state_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert pause_entity is not None
|
||||||
|
retrieved_state = pause_entity.get_state()
|
||||||
|
if isinstance(retrieved_state, bytes):
|
||||||
|
retrieved_state = retrieved_state.decode()
|
||||||
|
assert retrieved_state == large_state_json
|
||||||
|
|
||||||
|
# Verify file size in database
|
||||||
|
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||||
|
assert pause_model.state_object_key != ""
|
||||||
|
loaded_state = storage.load(pause_model.state_object_key)
|
||||||
|
assert loaded_state.decode() == large_state_json
|
||||||
|
|
||||||
|
def test_multiple_pause_resume_cycles(self):
|
||||||
|
"""Test multiple pause/resume cycles on the same workflow run."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run = self._create_test_workflow_run()
|
||||||
|
repository = self._get_workflow_run_repository()
|
||||||
|
|
||||||
|
# Act & Assert - Multiple cycles
|
||||||
|
for i in range(3):
|
||||||
|
state = json.dumps({"cycle": i, "data": f"state_{i}"})
|
||||||
|
|
||||||
|
# Reset workflow run status to RUNNING before each pause (after first cycle)
|
||||||
|
if i > 0:
|
||||||
|
self.session.refresh(workflow_run) # Refresh to get latest state from session
|
||||||
|
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(workflow_run) # Refresh again after commit
|
||||||
|
|
||||||
|
# Pause
|
||||||
|
pause_entity = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
state_owner_user_id=self.test_user_id,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
assert pause_entity is not None
|
||||||
|
|
||||||
|
# Verify pause
|
||||||
|
self.session.expire_all() # Clear session to ensure fresh query
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
|
||||||
|
# Use the test session directly to verify the pause
|
||||||
|
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
|
||||||
|
workflow_run_with_pause = self.session.scalar(stmt)
|
||||||
|
pause_model = workflow_run_with_pause.pause
|
||||||
|
|
||||||
|
# Verify pause using test session directly
|
||||||
|
assert pause_model is not None
|
||||||
|
assert pause_model.id == pause_entity.id
|
||||||
|
assert pause_model.state_object_key != ""
|
||||||
|
|
||||||
|
# Load file content using storage directly
|
||||||
|
file_content = storage.load(pause_model.state_object_key)
|
||||||
|
if isinstance(file_content, bytes):
|
||||||
|
file_content = file_content.decode()
|
||||||
|
assert file_content == state
|
||||||
|
|
||||||
|
# Resume
|
||||||
|
resumed_entity = repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
assert resumed_entity is not None
|
||||||
|
assert resumed_entity.resumed_at is not None
|
||||||
|
|
||||||
|
# Verify resume - check that pause is marked as resumed
|
||||||
|
self.session.expire_all() # Clear session to ensure fresh query
|
||||||
|
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
|
||||||
|
resumed_pause_model = self.session.scalar(stmt)
|
||||||
|
assert resumed_pause_model is not None
|
||||||
|
assert resumed_pause_model.resumed_at is not None
|
||||||
|
|
||||||
|
# Verify workflow run status
|
||||||
|
self.session.refresh(workflow_run)
|
||||||
|
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||||
|
|
@ -0,0 +1,278 @@
|
||||||
|
import json
|
||||||
|
from time import time
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||||
|
from core.variables.segments import Segment
|
||||||
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
|
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||||
|
from core.workflow.graph_events.graph import (
|
||||||
|
GraphRunFailedEvent,
|
||||||
|
GraphRunPausedEvent,
|
||||||
|
GraphRunStartedEvent,
|
||||||
|
GraphRunSucceededEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataFactory:
|
||||||
|
"""Factory helpers for constructing graph events used in tests."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||||
|
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||||
|
return GraphRunStartedEvent()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
|
||||||
|
return GraphRunSucceededEvent(outputs=outputs or {})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_graph_run_failed_event(
|
||||||
|
error: str = "Test error",
|
||||||
|
exceptions_count: int = 1,
|
||||||
|
) -> GraphRunFailedEvent:
|
||||||
|
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
|
||||||
|
|
||||||
|
|
||||||
|
class MockSystemVariableReadOnlyView:
|
||||||
|
"""Minimal read-only system variable view for testing."""
|
||||||
|
|
||||||
|
def __init__(self, workflow_execution_id: str | None = None) -> None:
|
||||||
|
self._workflow_execution_id = workflow_execution_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workflow_execution_id(self) -> str | None:
|
||||||
|
return self._workflow_execution_id
|
||||||
|
|
||||||
|
|
||||||
|
class MockReadOnlyVariablePool:
|
||||||
|
"""Mock implementation of ReadOnlyVariablePool for testing."""
|
||||||
|
|
||||||
|
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||||
|
self._variables = variables or {}
|
||||||
|
|
||||||
|
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||||
|
value = self._variables.get((node_id, variable_key))
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
mock_segment = Mock(spec=Segment)
|
||||||
|
mock_segment.value = value
|
||||||
|
return mock_segment
|
||||||
|
|
||||||
|
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||||
|
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||||
|
|
||||||
|
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||||
|
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
|
||||||
|
|
||||||
|
|
||||||
|
class MockReadOnlyGraphRuntimeState:
|
||||||
|
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
start_at: float | None = None,
|
||||||
|
total_tokens: int = 0,
|
||||||
|
node_run_steps: int = 0,
|
||||||
|
ready_queue_size: int = 0,
|
||||||
|
exceptions_count: int = 0,
|
||||||
|
outputs: dict[str, object] | None = None,
|
||||||
|
variables: dict[tuple[str, str], object] | None = None,
|
||||||
|
workflow_execution_id: str | None = None,
|
||||||
|
):
|
||||||
|
self._start_at = start_at or time()
|
||||||
|
self._total_tokens = total_tokens
|
||||||
|
self._node_run_steps = node_run_steps
|
||||||
|
self._ready_queue_size = ready_queue_size
|
||||||
|
self._exceptions_count = exceptions_count
|
||||||
|
self._outputs = outputs or {}
|
||||||
|
self._variable_pool = MockReadOnlyVariablePool(variables)
|
||||||
|
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def system_variable(self) -> MockSystemVariableReadOnlyView:
|
||||||
|
return self._system_variable
|
||||||
|
|
||||||
|
@property
|
||||||
|
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||||
|
return self._variable_pool
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_at(self) -> float:
|
||||||
|
return self._start_at
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
return self._total_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_run_steps(self) -> int:
|
||||||
|
return self._node_run_steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ready_queue_size(self) -> int:
|
||||||
|
return self._ready_queue_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exceptions_count(self) -> int:
|
||||||
|
return self._exceptions_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> dict[str, object]:
|
||||||
|
return self._outputs.copy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_usage(self):
|
||||||
|
mock_usage = Mock()
|
||||||
|
mock_usage.prompt_tokens = 10
|
||||||
|
mock_usage.completion_tokens = 20
|
||||||
|
mock_usage.total_tokens = 30
|
||||||
|
return mock_usage
|
||||||
|
|
||||||
|
def get_output(self, key: str, default: object = None) -> object:
|
||||||
|
return self._outputs.get(key, default)
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"start_at": self._start_at,
|
||||||
|
"total_tokens": self._total_tokens,
|
||||||
|
"node_run_steps": self._node_run_steps,
|
||||||
|
"ready_queue_size": self._ready_queue_size,
|
||||||
|
"exceptions_count": self._exceptions_count,
|
||||||
|
"outputs": self._outputs,
|
||||||
|
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
|
||||||
|
"workflow_execution_id": self._system_variable.workflow_execution_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockCommandChannel:
|
||||||
|
"""Mock implementation of CommandChannel for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._commands: list[GraphEngineCommand] = []
|
||||||
|
|
||||||
|
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||||
|
return self._commands.copy()
|
||||||
|
|
||||||
|
def send_command(self, command: GraphEngineCommand) -> None:
|
||||||
|
self._commands.append(command)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPauseStatePersistenceLayer:
|
||||||
|
"""Unit tests for PauseStatePersistenceLayer."""
|
||||||
|
|
||||||
|
def test_init_with_dependency_injection(self):
|
||||||
|
session_factory = Mock(name="session_factory")
|
||||||
|
state_owner_user_id = "user-123"
|
||||||
|
|
||||||
|
layer = PauseStatePersistenceLayer(
|
||||||
|
session_factory=session_factory,
|
||||||
|
state_owner_user_id=state_owner_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert layer._session_maker is session_factory
|
||||||
|
assert layer._state_owner_user_id == state_owner_user_id
|
||||||
|
assert not hasattr(layer, "graph_runtime_state")
|
||||||
|
assert not hasattr(layer, "command_channel")
|
||||||
|
|
||||||
|
def test_initialize_sets_dependencies(self):
|
||||||
|
session_factory = Mock(name="session_factory")
|
||||||
|
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
|
||||||
|
|
||||||
|
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||||
|
command_channel = MockCommandChannel()
|
||||||
|
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
assert layer.graph_runtime_state is graph_runtime_state
|
||||||
|
assert layer.command_channel is command_channel
|
||||||
|
|
||||||
|
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
session_factory = Mock(name="session_factory")
|
||||||
|
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||||
|
|
||||||
|
mock_repo = Mock()
|
||||||
|
mock_factory = Mock(return_value=mock_repo)
|
||||||
|
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||||
|
|
||||||
|
graph_runtime_state = MockReadOnlyGraphRuntimeState(
|
||||||
|
outputs={"result": "test_output"},
|
||||||
|
total_tokens=100,
|
||||||
|
workflow_execution_id="run-123",
|
||||||
|
)
|
||||||
|
command_channel = MockCommandChannel()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
|
||||||
|
expected_state = graph_runtime_state.dumps()
|
||||||
|
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
mock_factory.assert_called_once_with(session_factory)
|
||||||
|
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||||
|
workflow_run_id="run-123",
|
||||||
|
state_owner_user_id="owner-123",
|
||||||
|
state=expected_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
session_factory = Mock(name="session_factory")
|
||||||
|
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||||
|
|
||||||
|
mock_repo = Mock()
|
||||||
|
mock_factory = Mock(return_value=mock_repo)
|
||||||
|
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||||
|
|
||||||
|
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||||
|
command_channel = MockCommandChannel()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
events = [
|
||||||
|
TestDataFactory.create_graph_run_started_event(),
|
||||||
|
TestDataFactory.create_graph_run_succeeded_event(),
|
||||||
|
TestDataFactory.create_graph_run_failed_event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
mock_factory.assert_not_called()
|
||||||
|
mock_repo.create_workflow_pause.assert_not_called()
|
||||||
|
|
||||||
|
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||||
|
session_factory = Mock(name="session_factory")
|
||||||
|
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||||
|
|
||||||
|
event = TestDataFactory.create_graph_run_paused_event()
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
session_factory = Mock(name="session_factory")
|
||||||
|
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||||
|
|
||||||
|
mock_repo = Mock()
|
||||||
|
mock_factory = Mock(return_value=mock_repo)
|
||||||
|
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||||
|
|
||||||
|
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
|
||||||
|
command_channel = MockCommandChannel()
|
||||||
|
layer.initialize(graph_runtime_state, command_channel)
|
||||||
|
|
||||||
|
event = TestDataFactory.create_graph_run_paused_event()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
layer.on_event(event)
|
||||||
|
|
||||||
|
mock_factory.assert_not_called()
|
||||||
|
mock_repo.create_workflow_pause.assert_not_called()
|
||||||
|
|
@ -0,0 +1,171 @@
|
||||||
|
"""Tests for _PrivateWorkflowPauseEntity implementation."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||||
|
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrivateWorkflowPauseEntity:
|
||||||
|
"""Test _PrivateWorkflowPauseEntity implementation."""
|
||||||
|
|
||||||
|
def test_entity_initialization(self):
|
||||||
|
"""Test entity initialization with required parameters."""
|
||||||
|
# Create mock models
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.id = "pause-123"
|
||||||
|
mock_pause_model.workflow_run_id = "execution-456"
|
||||||
|
mock_pause_model.resumed_at = None
|
||||||
|
|
||||||
|
# Create entity
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify initialization
|
||||||
|
assert entity._pause_model is mock_pause_model
|
||||||
|
assert entity._cached_state is None
|
||||||
|
|
||||||
|
def test_from_models_classmethod(self):
|
||||||
|
"""Test from_models class method."""
|
||||||
|
# Create mock models
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.id = "pause-123"
|
||||||
|
mock_pause_model.workflow_run_id = "execution-456"
|
||||||
|
|
||||||
|
# Create entity using from_models
|
||||||
|
entity = _PrivateWorkflowPauseEntity.from_models(
|
||||||
|
workflow_pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify entity creation
|
||||||
|
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||||
|
assert entity._pause_model is mock_pause_model
|
||||||
|
|
||||||
|
def test_id_property(self):
|
||||||
|
"""Test id property returns pause model ID."""
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.id = "pause-123"
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.id == "pause-123"
|
||||||
|
|
||||||
|
def test_workflow_execution_id_property(self):
|
||||||
|
"""Test workflow_execution_id property returns workflow run ID."""
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.workflow_run_id = "execution-456"
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.workflow_execution_id == "execution-456"
|
||||||
|
|
||||||
|
def test_resumed_at_property(self):
|
||||||
|
"""Test resumed_at property returns pause model resumed_at."""
|
||||||
|
resumed_at = datetime(2023, 12, 25, 15, 30, 45)
|
||||||
|
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.resumed_at = resumed_at
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.resumed_at == resumed_at
|
||||||
|
|
||||||
|
def test_resumed_at_property_none(self):
|
||||||
|
"""Test resumed_at property returns None when not set."""
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.resumed_at = None
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.resumed_at is None
|
||||||
|
|
||||||
|
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||||
|
def test_get_state_first_call(self, mock_storage):
|
||||||
|
"""Test get_state loads from storage on first call."""
|
||||||
|
state_data = b'{"test": "data", "step": 5}'
|
||||||
|
mock_storage.load.return_value = state_data
|
||||||
|
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.state_object_key = "test-state-key"
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call should load from storage
|
||||||
|
result = entity.get_state()
|
||||||
|
|
||||||
|
assert result == state_data
|
||||||
|
mock_storage.load.assert_called_once_with("test-state-key")
|
||||||
|
assert entity._cached_state == state_data
|
||||||
|
|
||||||
|
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||||
|
def test_get_state_cached_call(self, mock_storage):
|
||||||
|
"""Test get_state returns cached data on subsequent calls."""
|
||||||
|
state_data = b'{"test": "data", "step": 5}'
|
||||||
|
mock_storage.load.return_value = state_data
|
||||||
|
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
mock_pause_model.state_object_key = "test-state-key"
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = entity.get_state()
|
||||||
|
# Second call should use cache
|
||||||
|
result2 = entity.get_state()
|
||||||
|
|
||||||
|
assert result1 == state_data
|
||||||
|
assert result2 == state_data
|
||||||
|
# Storage should only be called once
|
||||||
|
mock_storage.load.assert_called_once_with("test-state-key")
|
||||||
|
|
||||||
|
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||||
|
def test_get_state_with_pre_cached_data(self, mock_storage):
|
||||||
|
"""Test get_state returns pre-cached data."""
|
||||||
|
state_data = b'{"test": "data", "step": 5}'
|
||||||
|
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-cache data
|
||||||
|
entity._cached_state = state_data
|
||||||
|
|
||||||
|
# Should return cached data without calling storage
|
||||||
|
result = entity.get_state()
|
||||||
|
|
||||||
|
assert result == state_data
|
||||||
|
mock_storage.load.assert_not_called()
|
||||||
|
|
||||||
|
def test_entity_with_binary_state_data(self):
|
||||||
|
"""Test entity with binary state data."""
|
||||||
|
# Test with binary data that's not valid JSON
|
||||||
|
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
|
||||||
|
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||||
|
mock_storage.load.return_value = binary_data
|
||||||
|
|
||||||
|
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||||
|
|
||||||
|
entity = _PrivateWorkflowPauseEntity(
|
||||||
|
pause_model=mock_pause_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = entity.get_state()
|
||||||
|
|
||||||
|
assert result == binary_data
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
|
|
@ -149,8 +150,8 @@ def test_pause_command():
|
||||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||||
assert len(pause_events) == 1
|
assert len(pause_events) == 1
|
||||||
assert pause_events[0].reason == "User requested pause"
|
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
|
||||||
|
|
||||||
graph_execution = engine.graph_runtime_state.graph_execution
|
graph_execution = engine.graph_runtime_state.graph_execution
|
||||||
assert graph_execution.is_paused
|
assert graph_execution.is_paused
|
||||||
assert graph_execution.pause_reason == "User requested pause"
|
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""Tests for workflow pause related enums and constants."""
|
||||||
|
|
||||||
|
from core.workflow.enums import (
|
||||||
|
WorkflowExecutionStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowExecutionStatus:
|
||||||
|
"""Test WorkflowExecutionStatus enum."""
|
||||||
|
|
||||||
|
def test_is_ended_method(self):
|
||||||
|
"""Test is_ended method for different statuses."""
|
||||||
|
# Test ended statuses
|
||||||
|
ended_statuses = [
|
||||||
|
WorkflowExecutionStatus.SUCCEEDED,
|
||||||
|
WorkflowExecutionStatus.FAILED,
|
||||||
|
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||||
|
WorkflowExecutionStatus.STOPPED,
|
||||||
|
]
|
||||||
|
|
||||||
|
for status in ended_statuses:
|
||||||
|
assert status.is_ended(), f"{status} should be considered ended"
|
||||||
|
|
||||||
|
# Test non-ended statuses
|
||||||
|
non_ended_statuses = [
|
||||||
|
WorkflowExecutionStatus.SCHEDULED,
|
||||||
|
WorkflowExecutionStatus.RUNNING,
|
||||||
|
WorkflowExecutionStatus.PAUSED,
|
||||||
|
]
|
||||||
|
|
||||||
|
for status in non_ended_statuses:
|
||||||
|
assert not status.is_ended(), f"{status} should not be considered ended"
|
||||||
|
|
@ -0,0 +1,202 @@
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.file.models import File, FileTransferMethod, FileType
|
||||||
|
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemVariableReadOnlyView:
|
||||||
|
"""Test cases for SystemVariableReadOnlyView class."""
|
||||||
|
|
||||||
|
def test_read_only_property_access(self):
|
||||||
|
"""Test that all properties return correct values from wrapped instance."""
|
||||||
|
# Create test data
|
||||||
|
test_file = File(
|
||||||
|
id="file-123",
|
||||||
|
tenant_id="tenant-123",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="related-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
datasource_info = {"key": "value", "nested": {"data": 42}}
|
||||||
|
|
||||||
|
# Create SystemVariable with all fields
|
||||||
|
system_var = SystemVariable(
|
||||||
|
user_id="user-123",
|
||||||
|
app_id="app-123",
|
||||||
|
workflow_id="workflow-123",
|
||||||
|
files=[test_file],
|
||||||
|
workflow_execution_id="exec-123",
|
||||||
|
query="test query",
|
||||||
|
conversation_id="conv-123",
|
||||||
|
dialogue_count=5,
|
||||||
|
document_id="doc-123",
|
||||||
|
original_document_id="orig-doc-123",
|
||||||
|
dataset_id="dataset-123",
|
||||||
|
batch="batch-123",
|
||||||
|
datasource_type="type-123",
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
invoke_from="invoke-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Test all properties
|
||||||
|
assert read_only_view.user_id == "user-123"
|
||||||
|
assert read_only_view.app_id == "app-123"
|
||||||
|
assert read_only_view.workflow_id == "workflow-123"
|
||||||
|
assert read_only_view.workflow_execution_id == "exec-123"
|
||||||
|
assert read_only_view.query == "test query"
|
||||||
|
assert read_only_view.conversation_id == "conv-123"
|
||||||
|
assert read_only_view.dialogue_count == 5
|
||||||
|
assert read_only_view.document_id == "doc-123"
|
||||||
|
assert read_only_view.original_document_id == "orig-doc-123"
|
||||||
|
assert read_only_view.dataset_id == "dataset-123"
|
||||||
|
assert read_only_view.batch == "batch-123"
|
||||||
|
assert read_only_view.datasource_type == "type-123"
|
||||||
|
assert read_only_view.invoke_from == "invoke-123"
|
||||||
|
|
||||||
|
def test_defensive_copying_of_mutable_objects(self):
|
||||||
|
"""Test that mutable objects are defensively copied."""
|
||||||
|
# Create test data
|
||||||
|
test_file = File(
|
||||||
|
id="file-123",
|
||||||
|
tenant_id="tenant-123",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="related-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
datasource_info = {"key": "original_value"}
|
||||||
|
|
||||||
|
# Create SystemVariable
|
||||||
|
system_var = SystemVariable(
|
||||||
|
files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Test files defensive copying
|
||||||
|
files_copy = read_only_view.files
|
||||||
|
assert isinstance(files_copy, tuple) # Should be immutable tuple
|
||||||
|
assert len(files_copy) == 1
|
||||||
|
assert files_copy[0].id == "file-123"
|
||||||
|
|
||||||
|
# Verify it's a copy (can't modify original through view)
|
||||||
|
assert isinstance(files_copy, tuple)
|
||||||
|
# tuples don't have append method, so they're immutable
|
||||||
|
|
||||||
|
# Test datasource_info defensive copying
|
||||||
|
datasource_copy = read_only_view.datasource_info
|
||||||
|
assert datasource_copy is not None
|
||||||
|
assert datasource_copy["key"] == "original_value"
|
||||||
|
|
||||||
|
datasource_copy = cast(dict, datasource_copy)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
datasource_copy["key"] = "modified value"
|
||||||
|
|
||||||
|
# Verify original is unchanged
|
||||||
|
assert system_var.datasource_info is not None
|
||||||
|
assert system_var.datasource_info["key"] == "original_value"
|
||||||
|
assert read_only_view.datasource_info is not None
|
||||||
|
assert read_only_view.datasource_info["key"] == "original_value"
|
||||||
|
|
||||||
|
def test_always_accesses_latest_data(self):
|
||||||
|
"""Test that properties always return the latest data from wrapped instance."""
|
||||||
|
# Create SystemVariable
|
||||||
|
system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Verify initial value
|
||||||
|
assert read_only_view.user_id == "original-user"
|
||||||
|
|
||||||
|
# Modify the wrapped instance
|
||||||
|
system_var.user_id = "modified-user"
|
||||||
|
|
||||||
|
# Verify view returns the new value
|
||||||
|
assert read_only_view.user_id == "modified-user"
|
||||||
|
|
||||||
|
def test_repr_method(self):
|
||||||
|
"""Test the __repr__ method."""
|
||||||
|
# Create SystemVariable
|
||||||
|
system_var = SystemVariable(workflow_execution_id="exec-123")
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Test repr
|
||||||
|
repr_str = repr(read_only_view)
|
||||||
|
assert "SystemVariableReadOnlyView" in repr_str
|
||||||
|
assert "system_variable=" in repr_str
|
||||||
|
|
||||||
|
def test_none_value_handling(self):
|
||||||
|
"""Test that None values are properly handled."""
|
||||||
|
# Create SystemVariable with all None values except workflow_execution_id
|
||||||
|
system_var = SystemVariable(
|
||||||
|
user_id=None,
|
||||||
|
app_id=None,
|
||||||
|
workflow_id=None,
|
||||||
|
workflow_execution_id="exec-123",
|
||||||
|
query=None,
|
||||||
|
conversation_id=None,
|
||||||
|
dialogue_count=None,
|
||||||
|
document_id=None,
|
||||||
|
original_document_id=None,
|
||||||
|
dataset_id=None,
|
||||||
|
batch=None,
|
||||||
|
datasource_type=None,
|
||||||
|
datasource_info=None,
|
||||||
|
invoke_from=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Test all None values
|
||||||
|
assert read_only_view.user_id is None
|
||||||
|
assert read_only_view.app_id is None
|
||||||
|
assert read_only_view.workflow_id is None
|
||||||
|
assert read_only_view.query is None
|
||||||
|
assert read_only_view.conversation_id is None
|
||||||
|
assert read_only_view.dialogue_count is None
|
||||||
|
assert read_only_view.document_id is None
|
||||||
|
assert read_only_view.original_document_id is None
|
||||||
|
assert read_only_view.dataset_id is None
|
||||||
|
assert read_only_view.batch is None
|
||||||
|
assert read_only_view.datasource_type is None
|
||||||
|
assert read_only_view.datasource_info is None
|
||||||
|
assert read_only_view.invoke_from is None
|
||||||
|
|
||||||
|
# files should be empty tuple even when default list is empty
|
||||||
|
assert read_only_view.files == ()
|
||||||
|
|
||||||
|
def test_empty_files_handling(self):
|
||||||
|
"""Test that empty files list is handled correctly."""
|
||||||
|
# Create SystemVariable with empty files
|
||||||
|
system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Test files handling
|
||||||
|
assert read_only_view.files == ()
|
||||||
|
assert isinstance(read_only_view.files, tuple)
|
||||||
|
|
||||||
|
def test_empty_datasource_info_handling(self):
|
||||||
|
"""Test that empty datasource_info is handled correctly."""
|
||||||
|
# Create SystemVariable with empty datasource_info
|
||||||
|
system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
|
||||||
|
|
||||||
|
# Create read-only view
|
||||||
|
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||||
|
|
||||||
|
# Test datasource_info handling
|
||||||
|
assert read_only_view.datasource_info == {}
|
||||||
|
# Should be a copy, not the same object
|
||||||
|
assert read_only_view.datasource_info is not system_var.datasource_info
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
from models.base import DefaultFieldsMixin
|
||||||
|
|
||||||
|
|
||||||
|
class FooModel(DefaultFieldsMixin):
|
||||||
|
def __init__(self, id: str):
|
||||||
|
self.id = id
|
||||||
|
|
||||||
|
|
||||||
|
def test_repr():
|
||||||
|
foo_model = FooModel(id="test-id")
|
||||||
|
assert repr(foo_model) == "<FooModel(id=test-id)>"
|
||||||
|
|
@ -0,0 +1,370 @@
|
||||||
|
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||||
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
|
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||||
|
from models.workflow import WorkflowRun
|
||||||
|
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||||
|
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
_PrivateWorkflowPauseEntity,
|
||||||
|
_WorkflowRunError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||||
|
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session(self):
|
||||||
|
"""Create a mock session."""
|
||||||
|
return Mock(spec=Session)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session_maker(self, mock_session):
|
||||||
|
"""Create a mock sessionmaker."""
|
||||||
|
session_maker = Mock(spec=sessionmaker)
|
||||||
|
|
||||||
|
# Create a context manager mock
|
||||||
|
context_manager = Mock()
|
||||||
|
context_manager.__enter__ = Mock(return_value=mock_session)
|
||||||
|
context_manager.__exit__ = Mock(return_value=None)
|
||||||
|
session_maker.return_value = context_manager
|
||||||
|
|
||||||
|
# Mock session.begin() context manager
|
||||||
|
begin_context_manager = Mock()
|
||||||
|
begin_context_manager.__enter__ = Mock(return_value=None)
|
||||||
|
begin_context_manager.__exit__ = Mock(return_value=None)
|
||||||
|
mock_session.begin = Mock(return_value=begin_context_manager)
|
||||||
|
|
||||||
|
# Add missing session methods
|
||||||
|
mock_session.commit = Mock()
|
||||||
|
mock_session.rollback = Mock()
|
||||||
|
mock_session.add = Mock()
|
||||||
|
mock_session.delete = Mock()
|
||||||
|
mock_session.get = Mock()
|
||||||
|
mock_session.scalar = Mock()
|
||||||
|
mock_session.scalars = Mock()
|
||||||
|
|
||||||
|
# Also support expire_on_commit parameter
|
||||||
|
def make_session(expire_on_commit=None):
|
||||||
|
cm = Mock()
|
||||||
|
cm.__enter__ = Mock(return_value=mock_session)
|
||||||
|
cm.__exit__ = Mock(return_value=None)
|
||||||
|
return cm
|
||||||
|
|
||||||
|
session_maker.side_effect = make_session
|
||||||
|
return session_maker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def repository(self, mock_session_maker):
|
||||||
|
"""Create repository instance with mocked dependencies."""
|
||||||
|
|
||||||
|
# Create a testable subclass that implements the save method
|
||||||
|
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||||
|
def __init__(self, session_maker):
|
||||||
|
# Initialize without calling parent __init__ to avoid any instantiation issues
|
||||||
|
self._session_maker = session_maker
|
||||||
|
|
||||||
|
def save(self, execution):
|
||||||
|
"""Mock implementation of save method."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create repository instance
|
||||||
|
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
|
||||||
|
|
||||||
|
return repo
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_workflow_run(self):
|
||||||
|
"""Create a sample WorkflowRun model."""
|
||||||
|
workflow_run = Mock(spec=WorkflowRun)
|
||||||
|
workflow_run.id = "workflow-run-123"
|
||||||
|
workflow_run.tenant_id = "tenant-123"
|
||||||
|
workflow_run.app_id = "app-123"
|
||||||
|
workflow_run.workflow_id = "workflow-123"
|
||||||
|
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
return workflow_run
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_workflow_pause(self):
|
||||||
|
"""Create a sample WorkflowPauseModel."""
|
||||||
|
pause = Mock(spec=WorkflowPauseModel)
|
||||||
|
pause.id = "pause-123"
|
||||||
|
pause.workflow_id = "workflow-123"
|
||||||
|
pause.workflow_run_id = "workflow-run-123"
|
||||||
|
pause.state_object_key = "workflow-state-123.json"
|
||||||
|
pause.resumed_at = None
|
||||||
|
pause.created_at = datetime.now(UTC)
|
||||||
|
return pause
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||||
|
"""Test create_workflow_pause method."""
|
||||||
|
|
||||||
|
def test_create_workflow_pause_success(
|
||||||
|
self,
|
||||||
|
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
mock_session: Mock,
|
||||||
|
sample_workflow_run: Mock,
|
||||||
|
):
|
||||||
|
"""Test successful workflow pause creation."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run_id = "workflow-run-123"
|
||||||
|
state_owner_user_id = "user-123"
|
||||||
|
state = '{"test": "state"}'
|
||||||
|
|
||||||
|
mock_session.get.return_value = sample_workflow_run
|
||||||
|
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
|
||||||
|
mock_uuidv7.side_effect = ["pause-123"]
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||||
|
# Act
|
||||||
|
result = repository.create_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
state_owner_user_id=state_owner_user_id,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||||
|
assert result.id == "pause-123"
|
||||||
|
assert result.workflow_execution_id == workflow_run_id
|
||||||
|
|
||||||
|
# Verify database interactions
|
||||||
|
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
|
||||||
|
mock_storage.save.assert_called_once()
|
||||||
|
mock_session.add.assert_called()
|
||||||
|
# When using session.begin() context manager, commit is handled automatically
|
||||||
|
# No explicit commit call is expected
|
||||||
|
|
||||||
|
def test_create_workflow_pause_not_found(
|
||||||
|
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
|
||||||
|
):
|
||||||
|
"""Test workflow pause creation when workflow run not found."""
|
||||||
|
# Arrange
|
||||||
|
mock_session.get.return_value = None
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
|
||||||
|
repository.create_workflow_pause(
|
||||||
|
workflow_run_id="workflow-run-123",
|
||||||
|
state_owner_user_id="user-123",
|
||||||
|
state='{"test": "state"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||||
|
|
||||||
|
def test_create_workflow_pause_invalid_status(
|
||||||
|
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
|
||||||
|
):
|
||||||
|
"""Test workflow pause creation when workflow not in RUNNING status."""
|
||||||
|
# Arrange
|
||||||
|
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||||
|
mock_session.get.return_value = sample_workflow_run
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
|
||||||
|
repository.create_workflow_pause(
|
||||||
|
workflow_run_id="workflow-run-123",
|
||||||
|
state_owner_user_id="user-123",
|
||||||
|
state='{"test": "state"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||||
|
"""Test resume_workflow_pause method."""
|
||||||
|
|
||||||
|
def test_resume_workflow_pause_success(
|
||||||
|
self,
|
||||||
|
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
mock_session: Mock,
|
||||||
|
sample_workflow_run: Mock,
|
||||||
|
sample_workflow_pause: Mock,
|
||||||
|
):
|
||||||
|
"""Test successful workflow pause resume."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run_id = "workflow-run-123"
|
||||||
|
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||||
|
pause_entity.id = "pause-123"
|
||||||
|
|
||||||
|
# Setup workflow run and pause
|
||||||
|
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||||
|
sample_workflow_run.pause = sample_workflow_pause
|
||||||
|
sample_workflow_pause.resumed_at = None
|
||||||
|
|
||||||
|
mock_session.scalar.return_value = sample_workflow_run
|
||||||
|
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
|
||||||
|
mock_now.return_value = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||||
|
assert result.id == "pause-123"
|
||||||
|
|
||||||
|
# Verify state transitions
|
||||||
|
assert sample_workflow_pause.resumed_at is not None
|
||||||
|
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||||
|
|
||||||
|
# Verify database interactions
|
||||||
|
mock_session.add.assert_called()
|
||||||
|
# When using session.begin() context manager, commit is handled automatically
|
||||||
|
# No explicit commit call is expected
|
||||||
|
|
||||||
|
def test_resume_workflow_pause_not_paused(
|
||||||
|
self,
|
||||||
|
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
mock_session: Mock,
|
||||||
|
sample_workflow_run: Mock,
|
||||||
|
):
|
||||||
|
"""Test resume when workflow is not paused."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run_id = "workflow-run-123"
|
||||||
|
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||||
|
pause_entity.id = "pause-123"
|
||||||
|
|
||||||
|
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
mock_session.scalar.return_value = sample_workflow_run
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
|
||||||
|
repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_resume_workflow_pause_id_mismatch(
|
||||||
|
self,
|
||||||
|
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
mock_session: Mock,
|
||||||
|
sample_workflow_run: Mock,
|
||||||
|
sample_workflow_pause: Mock,
|
||||||
|
):
|
||||||
|
"""Test resume when pause ID doesn't match."""
|
||||||
|
# Arrange
|
||||||
|
workflow_run_id = "workflow-run-123"
|
||||||
|
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||||
|
pause_entity.id = "pause-456" # Different ID
|
||||||
|
|
||||||
|
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||||
|
sample_workflow_pause.id = "pause-123"
|
||||||
|
sample_workflow_run.pause = sample_workflow_pause
|
||||||
|
mock_session.scalar.return_value = sample_workflow_run
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
|
||||||
|
repository.resume_workflow_pause(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
pause_entity=pause_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||||
|
"""Test delete_workflow_pause method."""
|
||||||
|
|
||||||
|
def test_delete_workflow_pause_success(
|
||||||
|
self,
|
||||||
|
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
mock_session: Mock,
|
||||||
|
sample_workflow_pause: Mock,
|
||||||
|
):
|
||||||
|
"""Test successful workflow pause deletion."""
|
||||||
|
# Arrange
|
||||||
|
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||||
|
pause_entity.id = "pause-123"
|
||||||
|
|
||||||
|
mock_session.get.return_value = sample_workflow_pause
|
||||||
|
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||||
|
# Act
|
||||||
|
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||||
|
mock_session.delete.assert_called_once_with(sample_workflow_pause)
|
||||||
|
# When using session.begin() context manager, commit is handled automatically
|
||||||
|
# No explicit commit call is expected
|
||||||
|
|
||||||
|
def test_delete_workflow_pause_not_found(
|
||||||
|
self,
|
||||||
|
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||||
|
mock_session: Mock,
|
||||||
|
):
|
||||||
|
"""Test delete when pause not found."""
|
||||||
|
# Arrange
|
||||||
|
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||||
|
pause_entity.id = "pause-123"
|
||||||
|
|
||||||
|
mock_session.get.return_value = None
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
|
||||||
|
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||||
|
"""Test _PrivateWorkflowPauseEntity class."""
|
||||||
|
|
||||||
|
def test_from_models(self, sample_workflow_pause: Mock):
|
||||||
|
"""Test creating _PrivateWorkflowPauseEntity from models."""
|
||||||
|
# Act
|
||||||
|
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||||
|
assert entity._pause_model == sample_workflow_pause
|
||||||
|
|
||||||
|
def test_properties(self, sample_workflow_pause: Mock):
|
||||||
|
"""Test entity properties."""
|
||||||
|
# Arrange
|
||||||
|
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
assert entity.id == sample_workflow_pause.id
|
||||||
|
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
|
||||||
|
assert entity.resumed_at == sample_workflow_pause.resumed_at
|
||||||
|
|
||||||
|
def test_get_state(self, sample_workflow_pause: Mock):
|
||||||
|
"""Test getting state from storage."""
|
||||||
|
# Arrange
|
||||||
|
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||||
|
expected_state = b'{"test": "state"}'
|
||||||
|
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||||
|
mock_storage.load.return_value = expected_state
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = entity.get_state()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected_state
|
||||||
|
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||||
|
|
||||||
|
def test_get_state_caching(self, sample_workflow_pause: Mock):
|
||||||
|
"""Test state caching in get_state method."""
|
||||||
|
# Arrange
|
||||||
|
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||||
|
expected_state = b'{"test": "state"}'
|
||||||
|
|
||||||
|
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||||
|
mock_storage.load.return_value = expected_state
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result1 = entity.get_state()
|
||||||
|
result2 = entity.get_state() # Should use cache
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result1 == expected_state
|
||||||
|
assert result2 == expected_state
|
||||||
|
mock_storage.load.assert_called_once() # Only called once due to caching
|
||||||
|
|
@ -0,0 +1,200 @@
|
||||||
|
"""Comprehensive unit tests for WorkflowRunService class.
|
||||||
|
|
||||||
|
This test suite covers all pause state management operations including:
|
||||||
|
- Retrieving pause state for workflow runs
|
||||||
|
- Saving pause state with file uploads
|
||||||
|
- Marking paused workflows as resumed
|
||||||
|
- Error handling and edge cases
|
||||||
|
- Database transaction management
|
||||||
|
- Repository-based approach testing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, create_autospec, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
|
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||||
|
from services.workflow_run_service import (
|
||||||
|
WorkflowRunService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataFactory:
|
||||||
|
"""Factory class for creating test data objects."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_workflow_run_mock(
|
||||||
|
id: str = "workflow-run-123",
|
||||||
|
tenant_id: str = "tenant-456",
|
||||||
|
app_id: str = "app-789",
|
||||||
|
workflow_id: str = "workflow-101",
|
||||||
|
status: str | WorkflowExecutionStatus = "paused",
|
||||||
|
pause_id: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Create a mock WorkflowRun object."""
|
||||||
|
mock_run = MagicMock()
|
||||||
|
mock_run.id = id
|
||||||
|
mock_run.tenant_id = tenant_id
|
||||||
|
mock_run.app_id = app_id
|
||||||
|
mock_run.workflow_id = workflow_id
|
||||||
|
mock_run.status = status
|
||||||
|
mock_run.pause_id = pause_id
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(mock_run, key, value)
|
||||||
|
|
||||||
|
return mock_run
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_workflow_pause_mock(
|
||||||
|
id: str = "pause-123",
|
||||||
|
tenant_id: str = "tenant-456",
|
||||||
|
app_id: str = "app-789",
|
||||||
|
workflow_id: str = "workflow-101",
|
||||||
|
workflow_execution_id: str = "workflow-execution-123",
|
||||||
|
state_file_id: str = "file-456",
|
||||||
|
resumed_at: datetime | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Create a mock WorkflowPauseModel object."""
|
||||||
|
mock_pause = MagicMock()
|
||||||
|
mock_pause.id = id
|
||||||
|
mock_pause.tenant_id = tenant_id
|
||||||
|
mock_pause.app_id = app_id
|
||||||
|
mock_pause.workflow_id = workflow_id
|
||||||
|
mock_pause.workflow_execution_id = workflow_execution_id
|
||||||
|
mock_pause.state_file_id = state_file_id
|
||||||
|
mock_pause.resumed_at = resumed_at
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(mock_pause, key, value)
|
||||||
|
|
||||||
|
return mock_pause
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_upload_file_mock(
|
||||||
|
id: str = "file-456",
|
||||||
|
key: str = "upload_files/test/state.json",
|
||||||
|
name: str = "state.json",
|
||||||
|
tenant_id: str = "tenant-456",
|
||||||
|
**kwargs,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Create a mock UploadFile object."""
|
||||||
|
mock_file = MagicMock()
|
||||||
|
mock_file.id = id
|
||||||
|
mock_file.key = key
|
||||||
|
mock_file.name = name
|
||||||
|
mock_file.tenant_id = tenant_id
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(mock_file, key, value)
|
||||||
|
|
||||||
|
return mock_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_pause_entity_mock(
|
||||||
|
pause_model: MagicMock | None = None,
|
||||||
|
upload_file: MagicMock | None = None,
|
||||||
|
) -> _PrivateWorkflowPauseEntity:
|
||||||
|
"""Create a mock _PrivateWorkflowPauseEntity object."""
|
||||||
|
if pause_model is None:
|
||||||
|
pause_model = TestDataFactory.create_workflow_pause_mock()
|
||||||
|
if upload_file is None:
|
||||||
|
upload_file = TestDataFactory.create_upload_file_mock()
|
||||||
|
|
||||||
|
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowRunService:
|
||||||
|
"""Comprehensive unit tests for WorkflowRunService class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session_factory(self):
|
||||||
|
"""Create a mock session factory with proper session management."""
|
||||||
|
mock_session = create_autospec(Session)
|
||||||
|
|
||||||
|
# Create a mock context manager for the session
|
||||||
|
mock_session_cm = MagicMock()
|
||||||
|
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_session_cm.__exit__ = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
# Create a mock context manager for the transaction
|
||||||
|
mock_transaction_cm = MagicMock()
|
||||||
|
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
|
||||||
|
|
||||||
|
# Create mock factory that returns the context manager
|
||||||
|
mock_factory = MagicMock(spec=sessionmaker)
|
||||||
|
mock_factory.return_value = mock_session_cm
|
||||||
|
|
||||||
|
return mock_factory, mock_session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_workflow_run_repository(self):
|
||||||
|
"""Create a mock APIWorkflowRunRepository."""
|
||||||
|
mock_repo = create_autospec(APIWorkflowRunRepository)
|
||||||
|
return mock_repo
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
|
||||||
|
"""Create WorkflowRunService instance with mocked dependencies."""
|
||||||
|
session_factory, _ = mock_session_factory
|
||||||
|
|
||||||
|
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||||
|
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||||
|
service = WorkflowRunService(session_factory)
|
||||||
|
return service
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||||
|
"""Create WorkflowRunService instance with Engine input."""
|
||||||
|
mock_engine = create_autospec(Engine)
|
||||||
|
session_factory, _ = mock_session_factory
|
||||||
|
|
||||||
|
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||||
|
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||||
|
service = WorkflowRunService(mock_engine)
|
||||||
|
return service
|
||||||
|
|
||||||
|
# ==================== Initialization Tests ====================
|
||||||
|
|
||||||
|
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
|
||||||
|
"""Test WorkflowRunService initialization with session_factory."""
|
||||||
|
session_factory, _ = mock_session_factory
|
||||||
|
|
||||||
|
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||||
|
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||||
|
service = WorkflowRunService(session_factory)
|
||||||
|
|
||||||
|
assert service._session_factory == session_factory
|
||||||
|
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||||
|
|
||||||
|
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||||
|
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
|
||||||
|
mock_engine = create_autospec(Engine)
|
||||||
|
session_factory, _ = mock_session_factory
|
||||||
|
|
||||||
|
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||||
|
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||||
|
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
|
||||||
|
service = WorkflowRunService(mock_engine)
|
||||||
|
|
||||||
|
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||||
|
assert service._session_factory == session_factory
|
||||||
|
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||||
|
|
||||||
|
def test_init_with_default_dependencies(self, mock_session_factory):
|
||||||
|
"""Test WorkflowRunService initialization with default dependencies."""
|
||||||
|
session_factory, _ = mock_session_factory
|
||||||
|
|
||||||
|
service = WorkflowRunService(session_factory)
|
||||||
|
|
||||||
|
assert service._session_factory == session_factory
|
||||||
Loading…
Reference in New Issue