From 1c1f124891c8e07f4c18fd1f9239226547bf32ae Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 26 Nov 2025 19:59:34 +0800 Subject: [PATCH] Enhanced GraphEngine Pause Handling (#28196) This commit: 1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured. 2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: -LAN- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/.importlinter | 1 + .../app/layers/pause_state_persist_layer.py | 1 + api/core/workflow/entities/__init__.py | 6 -- api/core/workflow/entities/pause_reason.py | 47 ++++-------- .../graph_engine/domain/graph_execution.py | 12 ++-- .../event_management/event_manager.py | 8 ++- .../workflow/graph_engine/graph_engine.py | 8 +-- api/core/workflow/graph_events/graph.py | 3 +- .../nodes/human_input/human_input_node.py | 3 +- .../workflow/runtime/graph_runtime_state.py | 8 ++- ...b7a422_add_workflow_pause_reasons_table.py | 41 +++++++++++ api/models/workflow.py | 66 +++++++++++++++++ .../api_workflow_run_repository.py | 4 +- .../entities/workflow_pause.py | 15 ++++ .../sqlalchemy_api_workflow_run_repository.py | 71 +++++++++++++------ api/services/workflow_service.py | 3 +- .../layers/test_pause_state_persist_layer.py | 13 ++-- .../test_workflow_pause_integration.py | 25 +++++-- .../layers/test_pause_state_persist_layer.py | 16 +++-- .../entities/test_private_workflow_pause.py | 52 +++----------- .../workflow/graph/test_graph_validation.py | 3 +- .../graph_engine/test_command_system.py | 5 +- ..._sqlalchemy_api_workflow_run_repository.py | 21 +++--- .../test_workflow_run_service_pause.py | 28 +------- 24 files changed, 275 insertions(+), 185 deletions(-) create mode 100644 api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py rename api/{core/workflow => repositories}/entities/workflow_pause.py (77%) diff --git a/api/.importlinter b/api/.importlinter index 98fe5f50bb..24ece72b30 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -16,6 +16,7 @@ layers = graph nodes node_events + runtime entities containers = core.workflow diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 412eb98dd4..61a3e1baca 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): workflow_run_id=workflow_run_id, state_owner_user_id=self._state_owner_user_id, state=state.dumps(), + pause_reasons=event.reasons, ) def on_graph_end(self, error: Exception | None) -> None: diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index f4ce9052e0..be70e467a0 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,17 +1,11 @@ -from ..runtime.graph_runtime_state import GraphRuntimeState -from ..runtime.variable_pool import VariablePool from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution -from .workflow_pause import WorkflowPauseEntity __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", - "GraphRuntimeState", - "VariablePool", "WorkflowExecution", "WorkflowNodeExecution", - "WorkflowPauseEntity", ] diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py index 16ad3d639d..c6655b7eab 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/core/workflow/entities/pause_reason.py @@ -1,49 +1,26 @@ from enum import StrEnum, auto -from typing import Annotated, Any, ClassVar, TypeAlias +from typing import Annotated, Literal, TypeAlias -from pydantic import BaseModel, Discriminator, Tag +from pydantic import BaseModel, Field -class _PauseReasonType(StrEnum): +class PauseReasonType(StrEnum): HUMAN_INPUT_REQUIRED = auto() SCHEDULED_PAUSE = auto() -class _PauseReasonBase(BaseModel): - TYPE: ClassVar[_PauseReasonType] +class HumanInputRequired(BaseModel): + TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED + + form_id: str + # The identifier of the human input node causing the pause. + node_id: str -class HumanInputRequired(_PauseReasonBase): - TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED - - -class SchedulingPause(_PauseReasonBase): - TYPE = _PauseReasonType.SCHEDULED_PAUSE +class SchedulingPause(BaseModel): + TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE message: str -def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None: - if isinstance(v, _PauseReasonBase): - return v.TYPE - elif isinstance(v, dict): - reason_type_str = v.get("TYPE") - if reason_type_str is None: - return None - try: - reason_type = _PauseReasonType(reason_type_str) - except ValueError: - return None - return reason_type - else: - # return None if the discriminator value isn't found - return None - - -PauseReason: TypeAlias = Annotated[ - ( - Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)] - | Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)] - ), - Discriminator(_get_pause_reason_discriminator), -] +PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")] diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py index 3d587d6691..9ca607458f 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/core/workflow/graph_engine/domain/graph_execution.py @@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel): completed: bool = Field(default=False) aborted: bool = Field(default=False) paused: bool = Field(default=False) - pause_reason: PauseReason | None = Field(default=None) + pause_reasons: list[PauseReason] = Field(default_factory=list) error: GraphExecutionErrorState | None = Field(default=None) exceptions_count: int = Field(default=0) node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) @@ -107,7 +107,7 @@ class GraphExecution: completed: bool = False aborted: bool = False paused: bool = False - pause_reason: PauseReason | None = None + pause_reasons: list[PauseReason] = field(default_factory=list) error: Exception | None = None node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) exceptions_count: int = 0 @@ -137,10 +137,8 @@ class GraphExecution: raise RuntimeError("Cannot pause execution that has completed") if self.aborted: raise RuntimeError("Cannot pause execution that has been aborted") - if self.paused: - return self.paused = True - self.pause_reason = reason + self.pause_reasons.append(reason) def fail(self, error: Exception) -> None: """Mark the graph execution as failed.""" @@ -195,7 +193,7 @@ class GraphExecution: completed=self.completed, aborted=self.aborted, paused=self.paused, - pause_reason=self.pause_reason, + pause_reasons=self.pause_reasons, error=_serialize_error(self.error), exceptions_count=self.exceptions_count, node_executions=node_states, @@ -221,7 +219,7 @@ class GraphExecution: self.completed = state.completed self.aborted = state.aborted self.paused = state.paused - self.pause_reason = state.pause_reason + self.pause_reasons = state.pause_reasons self.error = _deserialize_error(state.error) self.exceptions_count = state.exceptions_count self.node_executions = { diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/core/workflow/graph_engine/event_management/event_manager.py index 689cf53cf0..71043b9a43 100644 --- a/api/core/workflow/graph_engine/event_management/event_manager.py +++ b/api/core/workflow/graph_engine/event_management/event_manager.py @@ -110,7 +110,13 @@ class EventManager: """ with self._lock.write_lock(): self._events.append(event) - self._notify_layers(event) + + # NOTE: `_notify_layers` is intentionally called outside the critical section + # to minimize lock contention and avoid blocking other readers or writers. + # + # The public `notify_layers` method also does not use a write lock, + # so protecting `_notify_layers` with a lock here is unnecessary. + self._notify_layers(event) def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: """ diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 98e1a20044..a4b2df2a8c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -232,7 +232,7 @@ class GraphEngine: self._graph_execution.start() else: self._graph_execution.paused = False - self._graph_execution.pause_reason = None + self._graph_execution.pause_reasons = [] start_event = GraphRunStartedEvent() self._event_manager.notify_layers(start_event) @@ -246,11 +246,11 @@ class GraphEngine: # Handle completion if self._graph_execution.is_paused: - pause_reason = self._graph_execution.pause_reason - assert pause_reason is not None, "pause_reason should not be None when execution is paused." + pause_reasons = self._graph_execution.pause_reasons + assert pause_reasons, "pause_reasons should not be empty when execution is paused." # Ensure we have a valid PauseReason for the event paused_event = GraphRunPausedEvent( - reason=pause_reason, + reasons=pause_reasons, outputs=self._graph_runtime_state.outputs, ) self._event_manager.notify_layers(paused_event) diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index 9faafc3173..5d10a76c15 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent): class GraphRunPausedEvent(BaseGraphEvent): """Event emitted when a graph run is paused by user command.""" - # reason: str | None = Field(default=None, description="reason for pause") - reason: PauseReason = Field(..., description="reason for pause") + reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list) outputs: dict[str, object] = Field( default_factory=dict, description="Outputs available to the client while the run is paused.", diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index 2d6d9760af..c0d64a060a 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -65,7 +65,8 @@ class HumanInputNode(Node): return self._pause_generator() def _pause_generator(self): - yield PauseRequestedEvent(reason=HumanInputRequired()) + # TODO(QuantumGhost): yield a real form id. + yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id)) def _is_completion_ready(self) -> bool: """Determine whether all required inputs are satisfied.""" diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 0fbc8ab23e..1561b789df 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -10,6 +10,7 @@ from typing import Any, Protocol from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.pause_reason import PauseReason from core.workflow.runtime.variable_pool import VariablePool @@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol): class GraphExecutionProtocol(Protocol): - """Structural interface for graph execution aggregate.""" + """Structural interface for graph execution aggregate. + + Defines the minimal set of attributes and methods required from a GraphExecution entity + for runtime orchestration and state management. + """ workflow_id: str started: bool @@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol): aborted: bool error: Exception | None exceptions_count: int + pause_reasons: list[PauseReason] def start(self) -> None: """Transition execution into the running state.""" diff --git a/api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py b/api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py new file mode 100644 index 0000000000..8478820999 --- /dev/null +++ b/api/migrations/versions/2025_11_18_1859-7bb281b7a422_add_workflow_pause_reasons_table.py @@ -0,0 +1,41 @@ +"""Add workflow_pauses_reasons table + +Revision ID: 7bb281b7a422 +Revises: 09cfdda155d1 +Create Date: 2025-11-18 18:59:26.999572 + +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "7bb281b7a422" +down_revision = "09cfdda155d1" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "workflow_pause_reasons", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + + sa.Column("pause_id", models.types.StringUUID(), nullable=False), + sa.Column("type_", sa.String(20), nullable=False), + sa.Column("form_id", sa.String(length=36), nullable=False), + sa.Column("node_id", sa.String(length=255), nullable=False), + sa.Column("message", sa.String(length=255), nullable=False), + + sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")), + ) + with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op: + batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False) + + +def downgrade(): + op.drop_table("workflow_pause_reasons") diff --git a/api/models/workflow.py b/api/models/workflow.py index f206a6a870..4efa829692 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -29,6 +29,7 @@ from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause from core.workflow.enums import NodeType from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type @@ -1728,3 +1729,68 @@ class WorkflowPause(DefaultFieldsMixin, Base): primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id", back_populates="pause", ) + + +class WorkflowPauseReason(DefaultFieldsMixin, Base): + __tablename__ = "workflow_pause_reasons" + + # `pause_id` represents the identifier of the pause, + # correspond to the `id` field of `WorkflowPause`. + pause_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) + + type_: Mapped[PauseReasonType] = mapped_column(EnumText(PauseReasonType), nullable=False) + + # form_id is not empty if and if only type_ == PauseReasonType.HUMAN_INPUT_REQUIRED + # + form_id: Mapped[str] = mapped_column( + String(36), + nullable=False, + default="", + ) + + # message records the text description of this pause reason. For example, + # "The workflow has been paused due to scheduling." + # + # Empty message means that this pause reason is not speified. + message: Mapped[str] = mapped_column( + String(255), + nullable=False, + default="", + ) + + # `node_id` is the identifier of node causing the pasue, correspond to + # `Node.id`. Empty `node_id` means that this pause reason is not caused by any specific node + # (E.G. time slicing pauses.) + node_id: Mapped[str] = mapped_column( + String(255), + nullable=False, + default="", + ) + + # Relationship to WorkflowPause + pause: Mapped[WorkflowPause] = orm.relationship( + foreign_keys=[pause_id], + # require explicit preloading. + lazy="raise", + uselist=False, + primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id", + ) + + @classmethod + def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": + if isinstance(pause_reason, HumanInputRequired): + return cls( + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id + ) + elif isinstance(pause_reason, SchedulingPause): + return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="") + else: + raise AssertionError(f"Unknown pause reason type: {pause_reason}") + + def to_entity(self) -> PauseReason: + if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: + return HumanInputRequired(form_id=self.form_id, node_id=self.node_id) + elif self.type_ == PauseReasonType.SCHEDULED_PAUSE: + return SchedulingPause(message=self.message) + else: + raise AssertionError(f"Unknown pause reason type: {self.type_}") diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 21fd57cd22..fd547c78ba 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol -from core.workflow.entities.workflow_pause import WorkflowPauseEntity +from core.workflow.entities.pause_reason import PauseReason from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowRun +from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, DailyRunsStats, @@ -257,6 +258,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): workflow_run_id: str, state_owner_user_id: str, state: str, + pause_reasons: Sequence[PauseReason], ) -> WorkflowPauseEntity: """ Create a new workflow pause state. diff --git a/api/core/workflow/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py similarity index 77% rename from api/core/workflow/entities/workflow_pause.py rename to api/repositories/entities/workflow_pause.py index 2f31c1ff53..b970f39816 100644 --- a/api/core/workflow/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -7,8 +7,11 @@ and don't contain implementation details like tenant_id, app_id, etc. """ from abc import ABC, abstractmethod +from collections.abc import Sequence from datetime import datetime +from core.workflow.entities.pause_reason import PauseReason + class WorkflowPauseEntity(ABC): """ @@ -59,3 +62,15 @@ class WorkflowPauseEntity(ABC): the pause is not resumed yet. """ pass + + @abstractmethod + def get_pause_reasons(self) -> Sequence[PauseReason]: + """ + Retrieve detailed reasons for this pause. + + Returns a sequence of `PauseReason` objects describing the specific nodes and + reasons for which the workflow execution was paused. + This information is related to, but distinct from, the `PauseReason` type + defined in `api/core/workflow/entities/pause_reason.py`. + """ + ... diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index eb2a32d764..b172c6a3ac 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -31,7 +31,7 @@ from sqlalchemy import and_, delete, func, null, or_, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities.workflow_pause import WorkflowPauseEntity +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause from core.workflow.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now @@ -41,8 +41,9 @@ from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowRun +from models.workflow import WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, DailyRunsStats, @@ -318,6 +319,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): workflow_run_id: str, state_owner_user_id: str, state: str, + pause_reasons: Sequence[PauseReason], ) -> WorkflowPauseEntity: """ Create a new workflow pause state. @@ -371,6 +373,25 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): pause_model.workflow_run_id = workflow_run.id pause_model.state_object_key = state_obj_key pause_model.created_at = naive_utc_now() + pause_reason_models = [] + for reason in pause_reasons: + if isinstance(reason, HumanInputRequired): + # TODO(QuantumGhost): record node_id for `WorkflowPauseReason` + pause_reason_model = WorkflowPauseReason( + pause_id=pause_model.id, + type_=reason.TYPE, + form_id=reason.form_id, + ) + elif isinstance(reason, SchedulingPause): + pause_reason_model = WorkflowPauseReason( + pause_id=pause_model.id, + type_=reason.TYPE, + message=reason.message, + ) + else: + raise AssertionError(f"unkown reason type: {type(reason)}") + + pause_reason_models.append(pause_reason_model) # Update workflow run status workflow_run.status = WorkflowExecutionStatus.PAUSED @@ -378,10 +399,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Save everything in a transaction session.add(pause_model) session.add(workflow_run) + session.add_all(pause_reason_models) logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity.from_models(pause_model) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models) + + def _get_reasons_by_pause_id(self, session: Session, pause_id: str): + reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id) + pause_reason_models = session.scalars(reason_stmt).all() + return pause_reason_models def get_workflow_pause( self, @@ -413,8 +440,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): pause_model = workflow_run.pause if pause_model is None: return None + pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id) - return _PrivateWorkflowPauseEntity.from_models(pause_model) + human_input_form: list[Any] = [] + # TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason + + return _PrivateWorkflowPauseEntity( + pause_model=pause_model, + reason_models=pause_reason_models, + human_input_form=human_input_form, + ) def resume_workflow_pause( self, @@ -466,6 +501,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if pause_model.resumed_at is not None: raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}") + pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id) + # Mark as resumed pause_model.resumed_at = naive_utc_now() workflow_run.pause_id = None # type: ignore @@ -476,7 +513,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity.from_models(pause_model) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons) def delete_workflow_pause( self, @@ -815,26 +852,13 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): self, *, pause_model: WorkflowPauseModel, + reason_models: Sequence[WorkflowPauseReason], + human_input_form: Sequence = (), ) -> None: self._pause_model = pause_model + self._reason_models = reason_models self._cached_state: bytes | None = None - - @classmethod - def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity": - """ - Create a _PrivateWorkflowPauseEntity from database models. - - Args: - workflow_pause_model: The WorkflowPause database model - upload_file_model: The UploadFile database model - - Returns: - _PrivateWorkflowPauseEntity: The constructed entity - - Raises: - ValueError: If required model attributes are missing - """ - return cls(pause_model=workflow_pause_model) + self._human_input_form = human_input_form @property def id(self) -> str: @@ -867,3 +891,6 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): @property def resumed_at(self) -> datetime | None: return self._pause_model.resumed_at + + def get_pause_reasons(self) -> Sequence[PauseReason]: + return [reason.to_entity() for reason in self._reason_models] diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b6764f1fa7..b45a167b73 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -15,7 +15,7 @@ from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion -from core.workflow.entities import VariablePool, WorkflowNodeExecution +from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent @@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from enums.cloud_plan import CloudPlan diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index bec3517d66..72469ad646 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -319,7 +319,7 @@ class TestPauseStatePersistenceLayerTestContainers: # Create pause event event = GraphRunPausedEvent( - reason=SchedulingPause(message="test pause"), + reasons=[SchedulingPause(message="test pause")], outputs={"intermediate": "result"}, ) @@ -381,7 +381,7 @@ class TestPauseStatePersistenceLayerTestContainers: command_channel = _TestCommandChannelImpl() layer.initialize(graph_runtime_state, command_channel) - event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) # Act - Save pause state layer.on_event(event) @@ -390,6 +390,7 @@ class TestPauseStatePersistenceLayerTestContainers: pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id) assert pause_entity is not None assert pause_entity.workflow_execution_id == self.test_workflow_run_id + assert pause_entity.get_pause_reasons() == event.reasons state_bytes = pause_entity.get_state() resumption_context = WorkflowResumptionContext.loads(state_bytes.decode()) @@ -414,7 +415,7 @@ class TestPauseStatePersistenceLayerTestContainers: command_channel = _TestCommandChannelImpl() layer.initialize(graph_runtime_state, command_channel) - event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) # Act layer.on_event(event) @@ -448,7 +449,7 @@ class TestPauseStatePersistenceLayerTestContainers: command_channel = _TestCommandChannelImpl() layer.initialize(graph_runtime_state, command_channel) - event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) # Act layer.on_event(event) @@ -514,7 +515,7 @@ class TestPauseStatePersistenceLayerTestContainers: command_channel = _TestCommandChannelImpl() layer.initialize(graph_runtime_state, command_channel) - event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) # Act layer.on_event(event) @@ -570,7 +571,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer = self._create_pause_state_persistence_layer() # Don't initialize - graph_runtime_state should not be set - event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause")) + event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) # Act & Assert - Should raise AttributeError with pytest.raises(AttributeError): diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 79da5d4d0e..889e3d1d83 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -334,12 +334,14 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) # Assert - Pause state created assert pause_entity is not None assert pause_entity.id is not None assert pause_entity.workflow_execution_id == workflow_run.id + assert list(pause_entity.get_pause_reasons()) == [] # Convert both to strings for comparison retrieved_state = pause_entity.get_state() if isinstance(retrieved_state, bytes): @@ -366,6 +368,7 @@ class TestWorkflowPauseIntegration: if isinstance(retrieved_state, bytes): retrieved_state = retrieved_state.decode() assert retrieved_state == test_state + assert list(retrieved_entity.get_pause_reasons()) == [] # Act - Resume workflow resumed_entity = repository.resume_workflow_pause( @@ -402,6 +405,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) assert pause_entity is not None @@ -432,6 +436,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) @pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name) @@ -449,6 +454,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) self.session.refresh(workflow_run) @@ -480,6 +486,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) self.session.refresh(workflow_run) @@ -503,6 +510,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) pause_model = self.session.get(WorkflowPauseModel, pause_entity.id) pause_model.resumed_at = naive_utc_now() @@ -530,6 +538,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=nonexistent_id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) def test_resume_nonexistent_workflow_run(self): @@ -543,6 +552,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) nonexistent_id = str(uuid.uuid4()) @@ -570,6 +580,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) # Manually adjust timestamps for testing @@ -648,6 +659,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) pause_entities.append(pause_entity) @@ -750,6 +762,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run1.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) # Try to access pause from tenant 2 using tenant 1's repository @@ -762,6 +775,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run2.id, state_owner_user_id=account2.id, state=test_state, + pause_reasons=[], ) # Assert - Both pauses should exist and be separate @@ -782,6 +796,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) # Verify pause is properly scoped @@ -802,6 +817,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, + pause_reasons=[], ) # Assert - Verify file was uploaded to storage @@ -828,9 +844,7 @@ class TestWorkflowPauseIntegration: repository = self._get_workflow_run_repository() pause_entity = repository.create_workflow_pause( - workflow_run_id=workflow_run.id, - state_owner_user_id=self.test_user_id, - state=test_state, + workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, pause_reasons=[] ) # Get file info before deletion @@ -868,6 +882,7 @@ class TestWorkflowPauseIntegration: workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=large_state_json, + pause_reasons=[], ) # Assert @@ -902,9 +917,7 @@ class TestWorkflowPauseIntegration: # Pause pause_entity = repository.create_workflow_pause( - workflow_run_id=workflow_run.id, - state_owner_user_id=self.test_user_id, - state=state, + workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=state, pause_reasons=[] ) assert pause_entity is not None diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 807f5e0fa5..534420f21e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,7 +31,7 @@ class TestDataFactory: @staticmethod def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent: - return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {}) + return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {}) @staticmethod def create_graph_run_started_event() -> GraphRunStartedEvent: @@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer: layer.on_event(event) mock_factory.assert_called_once_with(session_factory) - mock_repo.create_workflow_pause.assert_called_once_with( - workflow_run_id="run-123", - state_owner_user_id="owner-123", - state=mock_repo.create_workflow_pause.call_args.kwargs["state"], - ) - serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"] + assert mock_repo.create_workflow_pause.call_count == 1 + call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs + assert call_kwargs["workflow_run_id"] == "run-123" + assert call_kwargs["state_owner_user_id"] == "owner-123" + serialized_state = call_kwargs["state"] resumption_context = WorkflowResumptionContext.loads(serialized_state) assert resumption_context.serialized_graph_runtime_state == expected_state assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump() + pause_reasons = call_kwargs["pause_reasons"] + + assert isinstance(pause_reasons, list) def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch): session_factory = Mock(name="session_factory") diff --git a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py index ccb2dff85a..be165bf1c1 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py +++ b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py @@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model.resumed_at = None # Create entity - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) # Verify initialization assert entity._pause_model is mock_pause_model assert entity._cached_state is None - def test_from_models_classmethod(self): - """Test from_models class method.""" - # Create mock models - mock_pause_model = MagicMock(spec=WorkflowPauseModel) - mock_pause_model.id = "pause-123" - mock_pause_model.workflow_run_id = "execution-456" - - # Create entity using from_models - entity = _PrivateWorkflowPauseEntity.from_models( - workflow_pause_model=mock_pause_model, - ) - - # Verify entity creation - assert isinstance(entity, _PrivateWorkflowPauseEntity) - assert entity._pause_model is mock_pause_model - def test_id_property(self): """Test id property returns pause model ID.""" mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model.id = "pause-123" - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) assert entity.id == "pause-123" @@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model.workflow_run_id = "execution-456" - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) assert entity.workflow_execution_id == "execution-456" @@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model.resumed_at = resumed_at - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) assert entity.resumed_at == resumed_at @@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model.resumed_at = None - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) assert entity.resumed_at is None @@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model.state_object_key = "test-state-key" - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) # First call should load from storage result = entity.get_state() @@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) mock_pause_model.state_object_key = "test-state-key" - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) # First call result1 = entity.get_state() @@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) # Pre-cache data entity._cached_state = state_data @@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity: mock_pause_model = MagicMock(spec=WorkflowPauseModel) - entity = _PrivateWorkflowPauseEntity( - pause_model=mock_pause_model, - ) + entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[]) result = entity.get_state() diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index c55c40c5b4..0f62a11684 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -8,12 +8,13 @@ from typing import Any import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.entities import GraphInitParams from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType from core.workflow.graph import Graph from core.workflow.graph.validation import GraphValidationError from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node +from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 868edf9832..5d958803bc 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -178,8 +178,7 @@ def test_pause_command(): assert any(isinstance(e, GraphRunStartedEvent) for e in events) pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] assert len(pause_events) == 1 - assert pause_events[0].reason == SchedulingPause(message="User requested pause") + assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")] graph_execution = engine.graph_runtime_state.graph_execution - assert graph_execution.paused - assert graph_execution.pause_reason == SchedulingPause(message="User requested pause") + assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 73b35b8e63..0c34676252 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -6,10 +6,10 @@ from unittest.mock import Mock, patch import pytest from sqlalchemy.orm import Session, sessionmaker -from core.workflow.entities.workflow_pause import WorkflowPauseEntity from core.workflow.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause as WorkflowPauseModel from models.workflow import WorkflowRun +from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, _PrivateWorkflowPauseEntity, @@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): workflow_run_id=workflow_run_id, state_owner_user_id=state_owner_user_id, state=state, + pause_reasons=[], ) # Assert assert isinstance(result, _PrivateWorkflowPauseEntity) assert result.id == "pause-123" assert result.workflow_execution_id == workflow_run_id + assert result.get_pause_reasons() == [] # Verify database interactions mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id) @@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): workflow_run_id="workflow-run-123", state_owner_user_id="user-123", state='{"test": "state"}', + pause_reasons=[], ) mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123") @@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): workflow_run_id="workflow-run-123", state_owner_user_id="user-123", state='{"test": "state"}', + pause_reasons=[], ) @@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test _PrivateWorkflowPauseEntity class.""" - def test_from_models(self, sample_workflow_pause: Mock): - """Test creating _PrivateWorkflowPauseEntity from models.""" - # Act - entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) - - # Assert - assert isinstance(entity, _PrivateWorkflowPauseEntity) - assert entity._pause_model == sample_workflow_pause - def test_properties(self, sample_workflow_pause: Mock): """Test entity properties.""" # Arrange - entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) + entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) # Act & Assert assert entity.id == sample_workflow_pause.id @@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) def test_get_state(self, sample_workflow_pause: Mock): """Test getting state from storage.""" # Arrange - entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) + entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) expected_state = b'{"test": "state"}' with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: @@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) def test_get_state_caching(self, sample_workflow_pause: Mock): """Test state caching in get_state method.""" # Arrange - entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause) + entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) expected_state = b'{"test": "state"}' with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index a062d9444e..f45a72927e 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -17,6 +17,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.workflow.enums import WorkflowExecutionStatus +from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity from services.workflow_run_service import ( @@ -63,7 +64,7 @@ class TestDataFactory: **kwargs, ) -> MagicMock: """Create a mock WorkflowPauseModel object.""" - mock_pause = MagicMock() + mock_pause = MagicMock(spec=WorkflowPause) mock_pause.id = id mock_pause.tenant_id = tenant_id mock_pause.app_id = app_id @@ -77,38 +78,15 @@ class TestDataFactory: return mock_pause - @staticmethod - def create_upload_file_mock( - id: str = "file-456", - key: str = "upload_files/test/state.json", - name: str = "state.json", - tenant_id: str = "tenant-456", - **kwargs, - ) -> MagicMock: - """Create a mock UploadFile object.""" - mock_file = MagicMock() - mock_file.id = id - mock_file.key = key - mock_file.name = name - mock_file.tenant_id = tenant_id - - for key, value in kwargs.items(): - setattr(mock_file, key, value) - - return mock_file - @staticmethod def create_pause_entity_mock( pause_model: MagicMock | None = None, - upload_file: MagicMock | None = None, ) -> _PrivateWorkflowPauseEntity: """Create a mock _PrivateWorkflowPauseEntity object.""" if pause_model is None: pause_model = TestDataFactory.create_workflow_pause_mock() - if upload_file is None: - upload_file = TestDataFactory.create_upload_file_mock() - return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=[], human_input_form=[]) class TestWorkflowRunService: