From fd255e81e13299ac2caf86eab432d899282c73dc Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 12 Nov 2025 17:00:02 +0800 Subject: [PATCH] feat(api): Introduce `WorkflowResumptionContext` for pause state management (#28122) Certain metadata (including but not limited to `InvokeFrom`, `call_depth`, and `streaming`) is required when resuming a paused workflow. However, these fields are not part of `GraphRuntimeState` and were not saved in the previous implementation of `PauseStatePersistenceLayer`. This commit addresses this limitation by introducing a `WorkflowResumptionContext` model that wraps both the `*GenerateEntity` and `GraphRuntimeState`. This approach provides: - A structured container for all necessary resumption data - Better separation of concerns between execution state and persistence - Enhanced extensibility for future metadata additions - Clearer naming that distinguishes from `GraphRuntimeState` The `WorkflowResumptionContext` model makes extending the pause state easier while maintaining backward compatibility and proper version management for the entire execution state ecosystem. Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/app/entities/app_invoke_entities.py | 5 + .../app/layers/pause_state_persist_layer.py | 68 +++++++- .../layers/test_pause_state_persist_layer.py | 69 ++++++++- .../layers/test_pause_state_persist_layer.py | 146 +++++++++++++++++- 4 files changed, 273 insertions(+), 15 deletions(-) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 894f80a670..b49d4d6511 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -104,6 +104,11 @@ class AppGenerateEntity(BaseModel): inputs: Mapping[str, Any] files: Sequence[File] + + # Unique identifier of the user initiating the execution. + # This corresponds to `Account.id` for platform users or `EndUser.id` for end users. + # + # Note: The `user_id` field does not indicate whether the user is a platform user or an end user. user_id: str # extras diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 3dee75c082..7e79c22c4d 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,15 +1,64 @@ +from typing import Annotated, Literal, Self, TypeAlias + +from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent from core.workflow.graph_events.graph import GraphRunPausedEvent +from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +# Wrapper types for `WorkflowAppGenerateEntity` and +# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination +# and correct reconstruction of the entity field during (de)serialization. +class _WorkflowGenerateEntityWrapper(BaseModel): + type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW + entity: WorkflowAppGenerateEntity + + +class _AdvancedChatAppGenerateEntityWrapper(BaseModel): + type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT + entity: AdvancedChatAppGenerateEntity + + +_GenerateEntityUnion: TypeAlias = Annotated[ + _WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper, + Field(discriminator="type"), +] + + +class WorkflowResumptionContext(BaseModel): + """WorkflowResumptionContext captures all state necessary for resumption.""" + + version: Literal["1"] = "1" + + # Only workflow / chatflow could be paused. + generate_entity: _GenerateEntityUnion + serialized_graph_runtime_state: str + + def dumps(self) -> str: + return self.model_dump_json() + + @classmethod + def loads(cls, value: str) -> Self: + return cls.model_validate_json(value) + + def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity: + return self.generate_entity.entity + + class PauseStatePersistenceLayer(GraphEngineLayer): - def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str): + def __init__( + self, + session_factory: Engine | sessionmaker, + generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity, + state_owner_user_id: str, + ): """Create a PauseStatePersistenceLayer. The `state_owner_user_id` is used when creating state file for pause. @@ -19,6 +68,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): session_factory = sessionmaker(session_factory) self._session_maker = session_factory self._state_owner_user_id = state_owner_user_id + self._generate_entity = generate_entity def _get_repo(self) -> APIWorkflowRunRepository: return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker) @@ -49,13 +99,27 @@ class PauseStatePersistenceLayer(GraphEngineLayer): return assert self.graph_runtime_state is not None + + entity_wrapper: _GenerateEntityUnion + if isinstance(self._generate_entity, WorkflowAppGenerateEntity): + entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity) + elif isinstance(self._generate_entity, AdvancedChatAppGenerateEntity): + entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity) + else: + raise AssertionError(f"unknown entity type: type={type(self._generate_entity)}") + + state = WorkflowResumptionContext( + serialized_graph_runtime_state=self.graph_runtime_state.dumps(), + generate_entity=entity_wrapper, + ) + workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( workflow_run_id=workflow_run_id, state_owner_user_id=self._state_owner_user_id, - state=self.graph_runtime_state.dumps(), + state=state.dumps(), ) def on_graph_end(self, error: Exception | None) -> None: diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 133e600ca0..bec3517d66 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -25,7 +25,12 @@ import pytest from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session -from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, +) from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.enums import WorkflowExecutionStatus @@ -39,7 +44,7 @@ from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel -from models.model import UploadFile +from models.model import AppMode, UploadFile from models.workflow import Workflow, WorkflowRun from services.file_service import FileService from services.workflow_run_service import WorkflowRunService @@ -226,11 +231,39 @@ class TestPauseStatePersistenceLayerTestContainers: return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state) + def _create_generate_entity( + self, + workflow_execution_id: str | None = None, + user_id: str | None = None, + workflow_id: str | None = None, + ) -> WorkflowAppGenerateEntity: + execution_id = workflow_execution_id or getattr(self, "test_workflow_run_id", str(uuid.uuid4())) + wf_id = workflow_id or getattr(self, "test_workflow_id", str(uuid.uuid4())) + tenant_id = getattr(self, "test_tenant_id", "tenant-123") + app_id = getattr(self, "test_app_id", "app-123") + app_config = WorkflowUIBasedAppConfig( + tenant_id=str(tenant_id), + app_id=str(app_id), + app_mode=AppMode.WORKFLOW, + workflow_id=str(wf_id), + ) + return WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user_id or getattr(self, "test_user_id", str(uuid.uuid4())), + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_id=execution_id, + ) + def _create_pause_state_persistence_layer( self, workflow_run: WorkflowRun | None = None, workflow: Workflow | None = None, state_owner_user_id: str | None = None, + generate_entity: WorkflowAppGenerateEntity | None = None, ) -> PauseStatePersistenceLayer: """Create PauseStatePersistenceLayer with real dependencies.""" owner_id = state_owner_user_id @@ -244,10 +277,23 @@ class TestPauseStatePersistenceLayerTestContainers: assert owner_id is not None owner_id = str(owner_id) + workflow_execution_id = ( + workflow_run.id if workflow_run is not None else getattr(self, "test_workflow_run_id", None) + ) + assert workflow_execution_id is not None + workflow_id = workflow.id if workflow is not None else getattr(self, "test_workflow_id", None) + assert workflow_id is not None + entity_user_id = getattr(self, "test_user_id", owner_id) + entity = generate_entity or self._create_generate_entity( + workflow_execution_id=str(workflow_execution_id), + user_id=entity_user_id, + workflow_id=str(workflow_id), + ) return PauseStatePersistenceLayer( session_factory=self.session.get_bind(), state_owner_user_id=owner_id, + generate_entity=entity, ) def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers): @@ -297,10 +343,15 @@ class TestPauseStatePersistenceLayerTestContainers: assert pause_model.resumed_at is None storage_content = storage.load(pause_model.state_object_key).decode() + resumption_context = WorkflowResumptionContext.loads(storage_content) + assert resumption_context.version == "1" + assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps() expected_state = json.loads(graph_runtime_state.dumps()) - actual_state = json.loads(storage_content) - + actual_state = json.loads(resumption_context.serialized_graph_runtime_state) assert actual_state == expected_state + persisted_entity = resumption_context.get_generate_entity() + assert isinstance(persisted_entity, WorkflowAppGenerateEntity) + assert persisted_entity.workflow_execution_id == self.test_workflow_run_id def test_state_persistence_and_retrieval(self, db_session_with_containers): """Test that pause state can be persisted and retrieved correctly.""" @@ -341,13 +392,15 @@ class TestPauseStatePersistenceLayerTestContainers: assert pause_entity.workflow_execution_id == self.test_workflow_run_id state_bytes = pause_entity.get_state() - retrieved_state = json.loads(state_bytes.decode()) + resumption_context = WorkflowResumptionContext.loads(state_bytes.decode()) + retrieved_state = json.loads(resumption_context.serialized_graph_runtime_state) expected_state = json.loads(graph_runtime_state.dumps()) assert retrieved_state == expected_state assert retrieved_state["outputs"] == complex_outputs assert retrieved_state["total_tokens"] == 250 assert retrieved_state["node_run_steps"] == 10 + assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id def test_database_transaction_handling(self, db_session_with_containers): """Test that database transactions are handled correctly.""" @@ -410,7 +463,9 @@ class TestPauseStatePersistenceLayerTestContainers: # Verify content in storage storage_content = storage.load(pause_model.state_object_key).decode() - assert storage_content == graph_runtime_state.dumps() + resumption_context = WorkflowResumptionContext.loads(storage_content) + assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps() + assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id def test_workflow_with_different_creators(self, db_session_with_containers): """Test pause state with workflows created by different users.""" @@ -474,6 +529,8 @@ class TestPauseStatePersistenceLayerTestContainers: # Verify the state owner is the workflow creator pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id) assert pause_entity is not None + resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) + assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id def test_layer_ignores_non_pause_events(self, db_session_with_containers): """Test that layer ignores non-pause events.""" diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 3bd967cbc0..807f5e0fa5 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,7 +4,14 @@ from unittest.mock import Mock import pytest -from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand @@ -15,6 +22,7 @@ from core.workflow.graph_events.graph import ( GraphRunSucceededEvent, ) from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -170,6 +178,25 @@ class MockCommandChannel: class TestPauseStatePersistenceLayer: """Unit tests for PauseStatePersistenceLayer.""" + @staticmethod + def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-123", + app_id="app-123", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-123", + ) + return WorkflowAppGenerateEntity( + task_id="task-123", + app_config=app_config, + inputs={}, + files=[], + user_id="user-123", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_id=workflow_execution_id, + ) + def test_init_with_dependency_injection(self): session_factory = Mock(name="session_factory") state_owner_user_id = "user-123" @@ -177,6 +204,7 @@ class TestPauseStatePersistenceLayer: layer = PauseStatePersistenceLayer( session_factory=session_factory, state_owner_user_id=state_owner_user_id, + generate_entity=self._create_generate_entity(), ) assert layer._session_maker is session_factory @@ -186,7 +214,11 @@ class TestPauseStatePersistenceLayer: def test_initialize_sets_dependencies(self): session_factory = Mock(name="session_factory") - layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner") + layer = PauseStatePersistenceLayer( + session_factory=session_factory, + state_owner_user_id="owner", + generate_entity=self._create_generate_entity(), + ) graph_runtime_state = MockReadOnlyGraphRuntimeState() command_channel = MockCommandChannel() @@ -198,7 +230,12 @@ class TestPauseStatePersistenceLayer: def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch): session_factory = Mock(name="session_factory") - layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") + generate_entity = self._create_generate_entity(workflow_execution_id="run-123") + layer = PauseStatePersistenceLayer( + session_factory=session_factory, + state_owner_user_id="owner-123", + generate_entity=generate_entity, + ) mock_repo = Mock() mock_factory = Mock(return_value=mock_repo) @@ -221,12 +258,20 @@ class TestPauseStatePersistenceLayer: mock_repo.create_workflow_pause.assert_called_once_with( workflow_run_id="run-123", state_owner_user_id="owner-123", - state=expected_state, + state=mock_repo.create_workflow_pause.call_args.kwargs["state"], ) + serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"] + resumption_context = WorkflowResumptionContext.loads(serialized_state) + assert resumption_context.serialized_graph_runtime_state == expected_state + assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump() def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch): session_factory = Mock(name="session_factory") - layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") + layer = PauseStatePersistenceLayer( + session_factory=session_factory, + state_owner_user_id="owner-123", + generate_entity=self._create_generate_entity(), + ) mock_repo = Mock() mock_factory = Mock(return_value=mock_repo) @@ -250,7 +295,11 @@ class TestPauseStatePersistenceLayer: def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): session_factory = Mock(name="session_factory") - layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") + layer = PauseStatePersistenceLayer( + session_factory=session_factory, + state_owner_user_id="owner-123", + generate_entity=self._create_generate_entity(), + ) event = TestDataFactory.create_graph_run_paused_event() @@ -259,7 +308,11 @@ class TestPauseStatePersistenceLayer: def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): session_factory = Mock(name="session_factory") - layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") + layer = PauseStatePersistenceLayer( + session_factory=session_factory, + state_owner_user_id="owner-123", + generate_entity=self._create_generate_entity(), + ) mock_repo = Mock() mock_factory = Mock(return_value=mock_repo) @@ -276,3 +329,82 @@ class TestPauseStatePersistenceLayer: mock_factory.assert_not_called() mock_repo.create_workflow_pause.assert_not_called() + + +def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext: + """Create a WorkflowAppGenerateEntity with realistic data for WorkflowResumptionContext tests.""" + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-roundtrip", + app_id="app-roundtrip", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-roundtrip", + ) + serialized_state = json.dumps({"state": "workflow"}) + + return WorkflowResumptionContext( + serialized_graph_runtime_state=serialized_state, + generate_entity=_WorkflowGenerateEntityWrapper( + entity=WorkflowAppGenerateEntity( + task_id="workflow-task", + app_config=app_config, + inputs={"input_key": "input_value"}, + files=[], + user_id="user-roundtrip", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_id="workflow-exec-roundtrip", + ) + ), + ) + + +def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext: + """Create an AdvancedChatAppGenerateEntity with realistic data for WorkflowResumptionContext tests.""" + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-advanced", + app_id="app-advanced", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-advanced", + ) + serialized_state = json.dumps({"state": "workflow"}) + + return WorkflowResumptionContext( + serialized_graph_runtime_state=serialized_state, + generate_entity=_AdvancedChatAppGenerateEntityWrapper( + entity=AdvancedChatAppGenerateEntity( + task_id="advanced-task", + app_config=app_config, + inputs={"topic": "roundtrip"}, + files=[], + user_id="advanced-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_run_id="advanced-run-id", + query="Explain serialization behavior", + ) + ), + ) + + +@pytest.mark.parametrize( + "state", + [ + pytest.param( + _build_advanced_chat_generate_entity_for_roundtrip(), + id="advanced_chat", + ), + pytest.param( + _build_workflow_generate_entity_for_roundtrip(), + id="workflow", + ), + ], +) +def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext): + """WorkflowResumptionContext roundtrip preserves workflow generate entity metadata.""" + dumped = state.dumps() + loaded = WorkflowResumptionContext.loads(dumped) + + assert loaded == state + assert loaded.serialized_graph_runtime_state == state.serialized_graph_runtime_state + restored_entity = loaded.get_generate_entity() + assert isinstance(restored_entity, type(state.generate_entity.entity))