Enhanced GraphEngine Pause Handling (#28196)

This commit: 

1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured.
2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
QuantumGhost 2025-11-26 19:59:34 +08:00 committed by GitHub
parent b353a126d8
commit 1c1f124891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 275 additions and 185 deletions

View File

@ -16,6 +16,7 @@ layers =
graph graph
nodes nodes
node_events node_events
runtime
entities entities
containers = containers =
core.workflow core.workflow

View File

@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id, state_owner_user_id=self._state_owner_user_id,
state=state.dumps(), state=state.dumps(),
pause_reasons=event.reasons,
) )
def on_graph_end(self, error: Exception | None) -> None: def on_graph_end(self, error: Exception | None) -> None:

View File

@ -1,17 +1,11 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [ __all__ = [
"AgentNodeStrategyInit", "AgentNodeStrategyInit",
"GraphInitParams", "GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution", "WorkflowExecution",
"WorkflowNodeExecution", "WorkflowNodeExecution",
"WorkflowPauseEntity",
] ]

View File

@ -1,49 +1,26 @@
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias from typing import Annotated, Literal, TypeAlias
from pydantic import BaseModel, Discriminator, Tag from pydantic import BaseModel, Field
class _PauseReasonType(StrEnum): class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto() HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto() SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel): class HumanInputRequired(BaseModel):
TYPE: ClassVar[_PauseReasonType] TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
node_id: str
class HumanInputRequired(_PauseReasonBase): class SchedulingPause(BaseModel):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
message: str message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None: PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]

View File

@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False) completed: bool = Field(default=False)
aborted: bool = Field(default=False) aborted: bool = Field(default=False)
paused: bool = Field(default=False) paused: bool = Field(default=False)
pause_reason: PauseReason | None = Field(default=None) pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None) error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0) exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@ -107,7 +107,7 @@ class GraphExecution:
completed: bool = False completed: bool = False
aborted: bool = False aborted: bool = False
paused: bool = False paused: bool = False
pause_reason: PauseReason | None = None pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0 exceptions_count: int = 0
@ -137,10 +137,8 @@ class GraphExecution:
raise RuntimeError("Cannot pause execution that has completed") raise RuntimeError("Cannot pause execution that has completed")
if self.aborted: if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted") raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True self.paused = True
self.pause_reason = reason self.pause_reasons.append(reason)
def fail(self, error: Exception) -> None: def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed.""" """Mark the graph execution as failed."""
@ -195,7 +193,7 @@ class GraphExecution:
completed=self.completed, completed=self.completed,
aborted=self.aborted, aborted=self.aborted,
paused=self.paused, paused=self.paused,
pause_reason=self.pause_reason, pause_reasons=self.pause_reasons,
error=_serialize_error(self.error), error=_serialize_error(self.error),
exceptions_count=self.exceptions_count, exceptions_count=self.exceptions_count,
node_executions=node_states, node_executions=node_states,
@ -221,7 +219,7 @@ class GraphExecution:
self.completed = state.completed self.completed = state.completed
self.aborted = state.aborted self.aborted = state.aborted
self.paused = state.paused self.paused = state.paused
self.pause_reason = state.pause_reason self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error) self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count self.exceptions_count = state.exceptions_count
self.node_executions = { self.node_executions = {

View File

@ -110,7 +110,13 @@ class EventManager:
""" """
with self._lock.write_lock(): with self._lock.write_lock():
self._events.append(event) self._events.append(event)
self._notify_layers(event)
# NOTE: `_notify_layers` is intentionally called outside the critical section
# to minimize lock contention and avoid blocking other readers or writers.
#
# The public `notify_layers` method also does not use a write lock,
# so protecting `_notify_layers` with a lock here is unnecessary.
self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
""" """

View File

@ -232,7 +232,7 @@ class GraphEngine:
self._graph_execution.start() self._graph_execution.start()
else: else:
self._graph_execution.paused = False self._graph_execution.paused = False
self._graph_execution.pause_reason = None self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent() start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event) self._event_manager.notify_layers(start_event)
@ -246,11 +246,11 @@ class GraphEngine:
# Handle completion # Handle completion
if self._graph_execution.is_paused: if self._graph_execution.is_paused:
pause_reason = self._graph_execution.pause_reason pause_reasons = self._graph_execution.pause_reasons
assert pause_reason is not None, "pause_reason should not be None when execution is paused." assert pause_reasons, "pause_reasons should not be empty when execution is paused."
# Ensure we have a valid PauseReason for the event # Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent( paused_event = GraphRunPausedEvent(
reason=pause_reason, reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs, outputs=self._graph_runtime_state.outputs,
) )
self._event_manager.notify_layers(paused_event) self._event_manager.notify_layers(paused_event)

View File

@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent): class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command.""" """Event emitted when a graph run is paused by user command."""
# reason: str | None = Field(default=None, description="reason for pause") reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
reason: PauseReason = Field(..., description="reason for pause")
outputs: dict[str, object] = Field( outputs: dict[str, object] = Field(
default_factory=dict, default_factory=dict,
description="Outputs available to the client while the run is paused.", description="Outputs available to the client while the run is paused.",

View File

@ -65,7 +65,8 @@ class HumanInputNode(Node):
return self._pause_generator() return self._pause_generator()
def _pause_generator(self): def _pause_generator(self):
yield PauseRequestedEvent(reason=HumanInputRequired()) # TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool: def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied.""" """Determine whether all required inputs are satisfied."""

View File

@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool from core.workflow.runtime.variable_pool import VariablePool
@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
class GraphExecutionProtocol(Protocol): class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate.""" """Structural interface for graph execution aggregate.
Defines the minimal set of attributes and methods required from a GraphExecution entity
for runtime orchestration and state management.
"""
workflow_id: str workflow_id: str
started: bool started: bool
@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool aborted: bool
error: Exception | None error: Exception | None
exceptions_count: int exceptions_count: int
pause_reasons: list[PauseReason]
def start(self) -> None: def start(self) -> None:
"""Transition execution into the running state.""" """Transition execution into the running state."""

View File

@ -0,0 +1,41 @@
"""Add workflow_pauses_reasons table
Revision ID: 7bb281b7a422
Revises: 09cfdda155d1
Create Date: 2025-11-18 18:59:26.999572
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7bb281b7a422"
down_revision = "09cfdda155d1"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"workflow_pause_reasons",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("pause_id", models.types.StringUUID(), nullable=False),
sa.Column("type_", sa.String(20), nullable=False),
sa.Column("form_id", sa.String(length=36), nullable=False),
sa.Column("node_id", sa.String(length=255), nullable=False),
sa.Column("message", sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")),
)
with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op:
batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False)
def downgrade():
op.drop_table("workflow_pause_reasons")

View File

@ -29,6 +29,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID,
) )
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
from extensions.ext_storage import Storage from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -1728,3 +1729,68 @@ class WorkflowPause(DefaultFieldsMixin, Base):
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id", primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
back_populates="pause", back_populates="pause",
) )
class WorkflowPauseReason(DefaultFieldsMixin, Base):
__tablename__ = "workflow_pause_reasons"
# `pause_id` represents the identifier of the pause,
# correspond to the `id` field of `WorkflowPause`.
pause_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
type_: Mapped[PauseReasonType] = mapped_column(EnumText(PauseReasonType), nullable=False)
# form_id is not empty if and if only type_ == PauseReasonType.HUMAN_INPUT_REQUIRED
#
form_id: Mapped[str] = mapped_column(
String(36),
nullable=False,
default="",
)
# message records the text description of this pause reason. For example,
# "The workflow has been paused due to scheduling."
#
# Empty message means that this pause reason is not speified.
message: Mapped[str] = mapped_column(
String(255),
nullable=False,
default="",
)
# `node_id` is the identifier of node causing the pasue, correspond to
# `Node.id`. Empty `node_id` means that this pause reason is not caused by any specific node
# (E.G. time slicing pauses.)
node_id: Mapped[str] = mapped_column(
String(255),
nullable=False,
default="",
)
# Relationship to WorkflowPause
pause: Mapped[WorkflowPause] = orm.relationship(
foreign_keys=[pause_id],
# require explicit preloading.
lazy="raise",
uselist=False,
primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
)
elif isinstance(pause_reason, SchedulingPause):
return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
else:
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
def to_entity(self) -> PauseReason:
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
return SchedulingPause(message=self.message)
else:
raise AssertionError(f"Unknown pause reason type: {self.type_}")

View File

@ -38,11 +38,12 @@ from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import Protocol from typing import Protocol
from core.workflow.entities.workflow_pause import WorkflowPauseEntity from core.workflow.entities.pause_reason import PauseReason
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import ( from repositories.types import (
AverageInteractionStats, AverageInteractionStats,
DailyRunsStats, DailyRunsStats,
@ -257,6 +258,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
workflow_run_id: str, workflow_run_id: str,
state_owner_user_id: str, state_owner_user_id: str,
state: str, state: str,
pause_reasons: Sequence[PauseReason],
) -> WorkflowPauseEntity: ) -> WorkflowPauseEntity:
""" """
Create a new workflow pause state. Create a new workflow pause state.

View File

@ -7,8 +7,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from core.workflow.entities.pause_reason import PauseReason
class WorkflowPauseEntity(ABC): class WorkflowPauseEntity(ABC):
""" """
@ -59,3 +62,15 @@ class WorkflowPauseEntity(ABC):
the pause is not resumed yet. the pause is not resumed yet.
""" """
pass pass
@abstractmethod
def get_pause_reasons(self) -> Sequence[PauseReason]:
"""
Retrieve detailed reasons for this pause.
Returns a sequence of `PauseReason` objects describing the specific nodes and
reasons for which the workflow execution was paused.
This information is related to, but distinct from, the `PauseReason` type
defined in `api/core/workflow/entities/pause_reason.py`.
"""
...

View File

@ -31,7 +31,7 @@ from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy.engine import CursorResult from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.workflow_pause import WorkflowPauseEntity from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -41,8 +41,9 @@ from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7 from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowPause as WorkflowPauseModel from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import ( from repositories.types import (
AverageInteractionStats, AverageInteractionStats,
DailyRunsStats, DailyRunsStats,
@ -318,6 +319,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
workflow_run_id: str, workflow_run_id: str,
state_owner_user_id: str, state_owner_user_id: str,
state: str, state: str,
pause_reasons: Sequence[PauseReason],
) -> WorkflowPauseEntity: ) -> WorkflowPauseEntity:
""" """
Create a new workflow pause state. Create a new workflow pause state.
@ -371,6 +373,25 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_model.workflow_run_id = workflow_run.id pause_model.workflow_run_id = workflow_run.id
pause_model.state_object_key = state_obj_key pause_model.state_object_key = state_obj_key
pause_model.created_at = naive_utc_now() pause_model.created_at = naive_utc_now()
pause_reason_models = []
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
form_id=reason.form_id,
)
elif isinstance(reason, SchedulingPause):
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
message=reason.message,
)
else:
raise AssertionError(f"unkown reason type: {type(reason)}")
pause_reason_models.append(pause_reason_model)
# Update workflow run status # Update workflow run status
workflow_run.status = WorkflowExecutionStatus.PAUSED workflow_run.status = WorkflowExecutionStatus.PAUSED
@ -378,10 +399,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Save everything in a transaction # Save everything in a transaction
session.add(pause_model) session.add(pause_model)
session.add(workflow_run) session.add(workflow_run)
session.add_all(pause_reason_models)
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
return _PrivateWorkflowPauseEntity.from_models(pause_model) return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
pause_reason_models = session.scalars(reason_stmt).all()
return pause_reason_models
def get_workflow_pause( def get_workflow_pause(
self, self,
@ -413,8 +440,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_model = workflow_run.pause pause_model = workflow_run.pause
if pause_model is None: if pause_model is None:
return None return None
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
return _PrivateWorkflowPauseEntity.from_models(pause_model) human_input_form: list[Any] = []
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
return _PrivateWorkflowPauseEntity(
pause_model=pause_model,
reason_models=pause_reason_models,
human_input_form=human_input_form,
)
def resume_workflow_pause( def resume_workflow_pause(
self, self,
@ -466,6 +501,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if pause_model.resumed_at is not None: if pause_model.resumed_at is not None:
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}") raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
# Mark as resumed # Mark as resumed
pause_model.resumed_at = naive_utc_now() pause_model.resumed_at = naive_utc_now()
workflow_run.pause_id = None # type: ignore workflow_run.pause_id = None # type: ignore
@ -476,7 +513,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
return _PrivateWorkflowPauseEntity.from_models(pause_model) return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
def delete_workflow_pause( def delete_workflow_pause(
self, self,
@ -815,26 +852,13 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
self, self,
*, *,
pause_model: WorkflowPauseModel, pause_model: WorkflowPauseModel,
reason_models: Sequence[WorkflowPauseReason],
human_input_form: Sequence = (),
) -> None: ) -> None:
self._pause_model = pause_model self._pause_model = pause_model
self._reason_models = reason_models
self._cached_state: bytes | None = None self._cached_state: bytes | None = None
self._human_input_form = human_input_form
@classmethod
def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
"""
Create a _PrivateWorkflowPauseEntity from database models.
Args:
workflow_pause_model: The WorkflowPause database model
upload_file_model: The UploadFile database model
Returns:
_PrivateWorkflowPauseEntity: The constructed entity
Raises:
ValueError: If required model attributes are missing
"""
return cls(pause_model=workflow_pause_model)
@property @property
def id(self) -> str: def id(self) -> str:
@ -867,3 +891,6 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
@property @property
def resumed_at(self) -> datetime | None: def resumed_at(self) -> datetime | None:
return self._pause_model.resumed_at return self._pause_model.resumed_at
def get_pause_reasons(self) -> Sequence[PauseReason]:
return [reason.to_entity() for reason in self._reason_models]

View File

@ -15,7 +15,7 @@ from core.file import File
from core.repositories import DifyCoreRepositoryFactory from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable from core.variables import Variable
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
from core.workflow.entities import VariablePool, WorkflowNodeExecution from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan from enums.cloud_plan import CloudPlan

View File

@ -319,7 +319,7 @@ class TestPauseStatePersistenceLayerTestContainers:
# Create pause event # Create pause event
event = GraphRunPausedEvent( event = GraphRunPausedEvent(
reason=SchedulingPause(message="test pause"), reasons=[SchedulingPause(message="test pause")],
outputs={"intermediate": "result"}, outputs={"intermediate": "result"},
) )
@ -381,7 +381,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl() command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel) layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act - Save pause state # Act - Save pause state
layer.on_event(event) layer.on_event(event)
@ -390,6 +390,7 @@ class TestPauseStatePersistenceLayerTestContainers:
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id) pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
assert pause_entity is not None assert pause_entity is not None
assert pause_entity.workflow_execution_id == self.test_workflow_run_id assert pause_entity.workflow_execution_id == self.test_workflow_run_id
assert pause_entity.get_pause_reasons() == event.reasons
state_bytes = pause_entity.get_state() state_bytes = pause_entity.get_state()
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode()) resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
@ -414,7 +415,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl() command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel) layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act # Act
layer.on_event(event) layer.on_event(event)
@ -448,7 +449,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl() command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel) layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act # Act
layer.on_event(event) layer.on_event(event)
@ -514,7 +515,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl() command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel) layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act # Act
layer.on_event(event) layer.on_event(event)
@ -570,7 +571,7 @@ class TestPauseStatePersistenceLayerTestContainers:
layer = self._create_pause_state_persistence_layer() layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set # Don't initialize - graph_runtime_state should not be set
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act & Assert - Should raise AttributeError # Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError): with pytest.raises(AttributeError):

View File

@ -334,12 +334,14 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
# Assert - Pause state created # Assert - Pause state created
assert pause_entity is not None assert pause_entity is not None
assert pause_entity.id is not None assert pause_entity.id is not None
assert pause_entity.workflow_execution_id == workflow_run.id assert pause_entity.workflow_execution_id == workflow_run.id
assert list(pause_entity.get_pause_reasons()) == []
# Convert both to strings for comparison # Convert both to strings for comparison
retrieved_state = pause_entity.get_state() retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes): if isinstance(retrieved_state, bytes):
@ -366,6 +368,7 @@ class TestWorkflowPauseIntegration:
if isinstance(retrieved_state, bytes): if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode() retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state assert retrieved_state == test_state
assert list(retrieved_entity.get_pause_reasons()) == []
# Act - Resume workflow # Act - Resume workflow
resumed_entity = repository.resume_workflow_pause( resumed_entity = repository.resume_workflow_pause(
@ -402,6 +405,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
assert pause_entity is not None assert pause_entity is not None
@ -432,6 +436,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name) @pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
@ -449,6 +454,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
self.session.refresh(workflow_run) self.session.refresh(workflow_run)
@ -480,6 +486,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
self.session.refresh(workflow_run) self.session.refresh(workflow_run)
@ -503,6 +510,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.resumed_at = naive_utc_now() pause_model.resumed_at = naive_utc_now()
@ -530,6 +538,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=nonexistent_id, workflow_run_id=nonexistent_id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
def test_resume_nonexistent_workflow_run(self): def test_resume_nonexistent_workflow_run(self):
@ -543,6 +552,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
nonexistent_id = str(uuid.uuid4()) nonexistent_id = str(uuid.uuid4())
@ -570,6 +580,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
# Manually adjust timestamps for testing # Manually adjust timestamps for testing
@ -648,6 +659,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
pause_entities.append(pause_entity) pause_entities.append(pause_entity)
@ -750,6 +762,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run1.id, workflow_run_id=workflow_run1.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
# Try to access pause from tenant 2 using tenant 1's repository # Try to access pause from tenant 2 using tenant 1's repository
@ -762,6 +775,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run2.id, workflow_run_id=workflow_run2.id,
state_owner_user_id=account2.id, state_owner_user_id=account2.id,
state=test_state, state=test_state,
pause_reasons=[],
) )
# Assert - Both pauses should exist and be separate # Assert - Both pauses should exist and be separate
@ -782,6 +796,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
# Verify pause is properly scoped # Verify pause is properly scoped
@ -802,6 +817,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=test_state, state=test_state,
pause_reasons=[],
) )
# Assert - Verify file was uploaded to storage # Assert - Verify file was uploaded to storage
@ -828,9 +844,7 @@ class TestWorkflowPauseIntegration:
repository = self._get_workflow_run_repository() repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause( pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, pause_reasons=[]
state_owner_user_id=self.test_user_id,
state=test_state,
) )
# Get file info before deletion # Get file info before deletion
@ -868,6 +882,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id, state_owner_user_id=self.test_user_id,
state=large_state_json, state=large_state_json,
pause_reasons=[],
) )
# Assert # Assert
@ -902,9 +917,7 @@ class TestWorkflowPauseIntegration:
# Pause # Pause
pause_entity = repository.create_workflow_pause( pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=state, pause_reasons=[]
state_owner_user_id=self.test_user_id,
state=state,
) )
assert pause_entity is not None assert pause_entity is not None

View File

@ -31,7 +31,7 @@ class TestDataFactory:
@staticmethod @staticmethod
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent: def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {}) return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {})
@staticmethod @staticmethod
def create_graph_run_started_event() -> GraphRunStartedEvent: def create_graph_run_started_event() -> GraphRunStartedEvent:
@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer:
layer.on_event(event) layer.on_event(event)
mock_factory.assert_called_once_with(session_factory) mock_factory.assert_called_once_with(session_factory)
mock_repo.create_workflow_pause.assert_called_once_with( assert mock_repo.create_workflow_pause.call_count == 1
workflow_run_id="run-123", call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs
state_owner_user_id="owner-123", assert call_kwargs["workflow_run_id"] == "run-123"
state=mock_repo.create_workflow_pause.call_args.kwargs["state"], assert call_kwargs["state_owner_user_id"] == "owner-123"
) serialized_state = call_kwargs["state"]
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
resumption_context = WorkflowResumptionContext.loads(serialized_state) resumption_context = WorkflowResumptionContext.loads(serialized_state)
assert resumption_context.serialized_graph_runtime_state == expected_state assert resumption_context.serialized_graph_runtime_state == expected_state
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump() assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
pause_reasons = call_kwargs["pause_reasons"]
assert isinstance(pause_reasons, list)
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch): def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")

View File

@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model.resumed_at = None mock_pause_model.resumed_at = None
# Create entity # Create entity
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
# Verify initialization # Verify initialization
assert entity._pause_model is mock_pause_model assert entity._pause_model is mock_pause_model
assert entity._cached_state is None assert entity._cached_state is None
def test_from_models_classmethod(self):
"""Test from_models class method."""
# Create mock models
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
mock_pause_model.workflow_run_id = "execution-456"
# Create entity using from_models
entity = _PrivateWorkflowPauseEntity.from_models(
workflow_pause_model=mock_pause_model,
)
# Verify entity creation
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model is mock_pause_model
def test_id_property(self): def test_id_property(self):
"""Test id property returns pause model ID.""" """Test id property returns pause model ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123" mock_pause_model.id = "pause-123"
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
assert entity.id == "pause-123" assert entity.id == "pause-123"
@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.workflow_run_id = "execution-456" mock_pause_model.workflow_run_id = "execution-456"
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
assert entity.workflow_execution_id == "execution-456" assert entity.workflow_execution_id == "execution-456"
@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = resumed_at mock_pause_model.resumed_at = resumed_at
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
assert entity.resumed_at == resumed_at assert entity.resumed_at == resumed_at
@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = None mock_pause_model.resumed_at = None
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
assert entity.resumed_at is None assert entity.resumed_at is None
@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key" mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
# First call should load from storage # First call should load from storage
result = entity.get_state() result = entity.get_state()
@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key" mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
# First call # First call
result1 = entity.get_state() result1 = entity.get_state()
@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
# Pre-cache data # Pre-cache data
entity._cached_state = state_data entity._cached_state = state_data
@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity( entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
pause_model=mock_pause_model,
)
result = entity.get_state() result = entity.get_state()

View File

@ -8,12 +8,13 @@ from typing import Any
import pytest import pytest
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom from models.enums import UserFrom

View File

@ -178,8 +178,7 @@ def test_pause_command():
assert any(isinstance(e, GraphRunStartedEvent) for e in events) assert any(isinstance(e, GraphRunStartedEvent) for e in events)
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
assert len(pause_events) == 1 assert len(pause_events) == 1
assert pause_events[0].reason == SchedulingPause(message="User requested pause") assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")]
graph_execution = engine.graph_runtime_state.graph_execution graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.paused assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")

View File

@ -6,10 +6,10 @@ from unittest.mock import Mock, patch
import pytest import pytest
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import ( from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository, DifyAPISQLAlchemyWorkflowRunRepository,
_PrivateWorkflowPauseEntity, _PrivateWorkflowPauseEntity,
@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
state_owner_user_id=state_owner_user_id, state_owner_user_id=state_owner_user_id,
state=state, state=state,
pause_reasons=[],
) )
# Assert # Assert
assert isinstance(result, _PrivateWorkflowPauseEntity) assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123" assert result.id == "pause-123"
assert result.workflow_execution_id == workflow_run_id assert result.workflow_execution_id == workflow_run_id
assert result.get_pause_reasons() == []
# Verify database interactions # Verify database interactions
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id) mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id="workflow-run-123", workflow_run_id="workflow-run-123",
state_owner_user_id="user-123", state_owner_user_id="user-123",
state='{"test": "state"}', state='{"test": "state"}',
pause_reasons=[],
) )
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123") mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id="workflow-run-123", workflow_run_id="workflow-run-123",
state_owner_user_id="user-123", state_owner_user_id="user-123",
state='{"test": "state"}', state='{"test": "state"}',
pause_reasons=[],
) )
@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository): class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test _PrivateWorkflowPauseEntity class.""" """Test _PrivateWorkflowPauseEntity class."""
def test_from_models(self, sample_workflow_pause: Mock):
"""Test creating _PrivateWorkflowPauseEntity from models."""
# Act
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
# Assert
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model == sample_workflow_pause
def test_properties(self, sample_workflow_pause: Mock): def test_properties(self, sample_workflow_pause: Mock):
"""Test entity properties.""" """Test entity properties."""
# Arrange # Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Act & Assert # Act & Assert
assert entity.id == sample_workflow_pause.id assert entity.id == sample_workflow_pause.id
@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
def test_get_state(self, sample_workflow_pause: Mock): def test_get_state(self, sample_workflow_pause: Mock):
"""Test getting state from storage.""" """Test getting state from storage."""
# Arrange # Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}' expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
def test_get_state_caching(self, sample_workflow_pause: Mock): def test_get_state_caching(self, sample_workflow_pause: Mock):
"""Test state caching in get_state method.""" """Test state caching in get_state method."""
# Arrange # Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}' expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:

View File

@ -17,6 +17,7 @@ from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause
from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import ( from services.workflow_run_service import (
@ -63,7 +64,7 @@ class TestDataFactory:
**kwargs, **kwargs,
) -> MagicMock: ) -> MagicMock:
"""Create a mock WorkflowPauseModel object.""" """Create a mock WorkflowPauseModel object."""
mock_pause = MagicMock() mock_pause = MagicMock(spec=WorkflowPause)
mock_pause.id = id mock_pause.id = id
mock_pause.tenant_id = tenant_id mock_pause.tenant_id = tenant_id
mock_pause.app_id = app_id mock_pause.app_id = app_id
@ -77,38 +78,15 @@ class TestDataFactory:
return mock_pause return mock_pause
@staticmethod
def create_upload_file_mock(
id: str = "file-456",
key: str = "upload_files/test/state.json",
name: str = "state.json",
tenant_id: str = "tenant-456",
**kwargs,
) -> MagicMock:
"""Create a mock UploadFile object."""
mock_file = MagicMock()
mock_file.id = id
mock_file.key = key
mock_file.name = name
mock_file.tenant_id = tenant_id
for key, value in kwargs.items():
setattr(mock_file, key, value)
return mock_file
@staticmethod @staticmethod
def create_pause_entity_mock( def create_pause_entity_mock(
pause_model: MagicMock | None = None, pause_model: MagicMock | None = None,
upload_file: MagicMock | None = None,
) -> _PrivateWorkflowPauseEntity: ) -> _PrivateWorkflowPauseEntity:
"""Create a mock _PrivateWorkflowPauseEntity object.""" """Create a mock _PrivateWorkflowPauseEntity object."""
if pause_model is None: if pause_model is None:
pause_model = TestDataFactory.create_workflow_pause_mock() pause_model = TestDataFactory.create_workflow_pause_mock()
if upload_file is None:
upload_file = TestDataFactory.create_upload_file_mock()
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file) return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=[], human_input_form=[])
class TestWorkflowRunService: class TestWorkflowRunService: