From 02d15ebd5a05ba21abb706a80c7ba90ee91e9b45 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Tue, 16 Sep 2025 19:38:10 +0800
Subject: [PATCH] feat(graph_engine): support dumps and loads in GraphExecution

---
 .../workflow/entities/graph_runtime_state.py  |  13 ++
 .../graph_engine/domain/graph_execution.py    | 142 ++++++++++++++-
 .../workflow/graph_engine/graph_engine.py     |   2 +
 .../test_graph_execution_serialization.py     | 165 ++++++++++++++++++
 4 files changed, 319 insertions(+), 3 deletions(-)
 create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py

diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py
index 2b29a36d82..c8fb1de20e 100644
--- a/api/core/workflow/entities/graph_runtime_state.py
+++ b/api/core/workflow/entities/graph_runtime_state.py
@@ -16,6 +16,7 @@ class GraphRuntimeState(BaseModel):
     _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
     _node_run_steps: int = PrivateAttr(default=0)
     _ready_queue_json: str = PrivateAttr()
+    _graph_execution_json: str = PrivateAttr()
 
     def __init__(
         self,
@@ -27,6 +28,7 @@ class GraphRuntimeState(BaseModel):
         outputs: dict[str, object] | None = None,
         node_run_steps: int = 0,
         ready_queue_json: str = "",
+        graph_execution_json: str = "",
         **kwargs: object,
     ):
         """Initialize the GraphRuntimeState with validation."""
@@ -54,6 +56,7 @@ class GraphRuntimeState(BaseModel):
 
         self._node_run_steps = node_run_steps
         self._ready_queue_json = ready_queue_json
+        self._graph_execution_json = graph_execution_json
 
     @property
     def variable_pool(self) -> VariablePool:
@@ -142,3 +145,13 @@ class GraphRuntimeState(BaseModel):
     def ready_queue_json(self) -> str:
         """Get a copy of the ready queue state."""
         return self._ready_queue_json
+
+    @property
+    def graph_execution_json(self) -> str:
+        """Get a copy of the serialized graph execution state."""
+        return self._graph_execution_json
+
+    @graph_execution_json.setter
+    def graph_execution_json(self, value: str) -> None:
+        """Set the serialized graph execution state."""
+        self._graph_execution_json = value
diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py
index c375b08fe0..5951af1087 100644
--- a/api/core/workflow/graph_engine/domain/graph_execution.py
+++ b/api/core/workflow/graph_engine/domain/graph_execution.py
@@ -1,12 +1,94 @@
-"""
-GraphExecution aggregate root managing the overall graph execution state.
-""" +"""GraphExecution aggregate root managing the overall graph execution state.""" + +from __future__ import annotations from dataclasses import dataclass, field +from importlib import import_module +from typing import Literal + +from pydantic import BaseModel, Field + +from core.workflow.enums import NodeState from .node_execution import NodeExecution +class GraphExecutionErrorState(BaseModel): + """Serializable representation of an execution error.""" + + module: str = Field(description="Module containing the exception class") + qualname: str = Field(description="Qualified name of the exception class") + message: str | None = Field(default=None, description="Exception message string") + + +class NodeExecutionState(BaseModel): + """Serializable representation of a node execution entity.""" + + node_id: str + state: NodeState = Field(default=NodeState.UNKNOWN) + retry_count: int = Field(default=0) + execution_id: str | None = Field(default=None) + error: str | None = Field(default=None) + + +class GraphExecutionState(BaseModel): + """Pydantic model describing serialized GraphExecution state.""" + + type: Literal["GraphExecution"] = Field(default="GraphExecution") + version: str = Field(default="1.0") + workflow_id: str + started: bool = Field(default=False) + completed: bool = Field(default=False) + aborted: bool = Field(default=False) + error: GraphExecutionErrorState | None = Field(default=None) + node_executions: list[NodeExecutionState] = Field(default_factory=list) + + +def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: + """Convert an exception into its serializable representation.""" + + if error is None: + return None + + return GraphExecutionErrorState( + module=error.__class__.__module__, + qualname=error.__class__.__qualname__, + message=str(error), + ) + + +def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: + """Locate an exception class from its module and qualified name.""" + + module = import_module(module_name) + attr: object = module + for part in qualname.split("."): + attr = getattr(attr, part) + + if isinstance(attr, type) and issubclass(attr, Exception): + return attr + + raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") + + +def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: + """Reconstruct an exception instance from serialized data.""" + + if state is None: + return None + + try: + exception_class = _resolve_exception_class(state.module, state.qualname) + if state.message is None: + return exception_class() + return exception_class(state.message) + except Exception: + # Fallback to RuntimeError when reconstruction fails + if state.message is None: + return RuntimeError(state.qualname) + return RuntimeError(state.message) + + @dataclass class GraphExecution: """ @@ -69,3 +151,57 @@ class GraphExecution: if not self.error: return None return str(self.error) + + def dumps(self) -> str: + """Serialize the aggregate state into a JSON string.""" + + node_states = [ + NodeExecutionState( + node_id=node_id, + state=node_execution.state, + retry_count=node_execution.retry_count, + execution_id=node_execution.execution_id, + error=node_execution.error, + ) + for node_id, node_execution in sorted(self.node_executions.items()) + ] + + state = GraphExecutionState( + workflow_id=self.workflow_id, + started=self.started, + completed=self.completed, + aborted=self.aborted, + error=_serialize_error(self.error), + node_executions=node_states, + ) + + return 
+
+    def loads(self, data: str) -> None:
+        """Restore aggregate state from a serialized JSON string."""
+
+        state = GraphExecutionState.model_validate_json(data)
+
+        if state.type != "GraphExecution":
+            raise ValueError(f"Invalid serialized data type: {state.type}")
+
+        if state.version != "1.0":
+            raise ValueError(f"Unsupported serialized version: {state.version}")
+
+        if self.workflow_id != state.workflow_id:
+            raise ValueError("Serialized workflow_id does not match aggregate identity")
+
+        self.started = state.started
+        self.completed = state.completed
+        self.aborted = state.aborted
+        self.error = _deserialize_error(state.error)
+        self.node_executions = {
+            item.node_id: NodeExecution(
+                node_id=item.node_id,
+                state=item.state,
+                retry_count=item.retry_count,
+                execution_id=item.execution_id,
+                error=item.error,
+            )
+            for item in state.node_executions
+        }
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index b0daf694ce..1a136d4365 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -68,6 +68,8 @@ class GraphEngine:
 
         # Graph execution tracks the overall execution state
         self._graph_execution = GraphExecution(workflow_id=workflow_id)
+        if graph_runtime_state.graph_execution_json != "":
+            self._graph_execution.loads(graph_runtime_state.graph_execution_json)
 
         # === Core Dependencies ===
         # Graph structure and configuration
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py
new file mode 100644
index 0000000000..2388e4d57b
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py
@@ -0,0 +1,165 @@
+"""Unit tests for GraphExecution serialization helpers."""
+
+from __future__ import annotations
+
+import json
+from unittest.mock import MagicMock
+
+from core.workflow.entities import GraphRuntimeState
+from core.workflow.enums import NodeExecutionType, NodeState
+from core.workflow.graph_engine import GraphEngine
+from core.workflow.graph_engine.domain import GraphExecution
+
+
+class CustomGraphExecutionError(Exception):
+    """Custom exception used to verify error serialization."""
+
+
+def test_graph_execution_serialization_round_trip() -> None:
+    """GraphExecution serialization restores full aggregate state."""
+    # Arrange
+    execution = GraphExecution(workflow_id="wf-1")
+    execution.start()
+    node_a = execution.get_or_create_node_execution("node-a")
+    node_a.mark_started(execution_id="exec-1")
+    node_a.increment_retry()
+    node_a.mark_failed("boom")
+    node_b = execution.get_or_create_node_execution("node-b")
+    node_b.mark_skipped()
+    execution.fail(CustomGraphExecutionError("serialization failure"))
+
+    # Act
+    serialized = execution.dumps()
+    payload = json.loads(serialized)
+    restored = GraphExecution(workflow_id="wf-1")
+    restored.loads(serialized)
+
+    # Assert
+    assert payload["type"] == "GraphExecution"
+    assert payload["version"] == "1.0"
+    assert restored.workflow_id == "wf-1"
+    assert restored.started is True
+    assert restored.completed is True
+    assert restored.aborted is False
+    assert isinstance(restored.error, CustomGraphExecutionError)
+    assert str(restored.error) == "serialization failure"
+    assert set(restored.node_executions) == {"node-a", "node-b"}
+    restored_node_a = restored.node_executions["node-a"]
+    assert restored_node_a.state is NodeState.TAKEN
+    assert restored_node_a.retry_count == 1
+    assert restored_node_a.execution_id == "exec-1"
+    assert restored_node_a.error == "boom"
+    restored_node_b = restored.node_executions["node-b"]
+    assert restored_node_b.state is NodeState.SKIPPED
+    assert restored_node_b.retry_count == 0
+    assert restored_node_b.execution_id is None
+    assert restored_node_b.error is None
+
+
+def test_graph_execution_loads_replaces_existing_state() -> None:
+    """loads replaces existing runtime data with serialized snapshot."""
+    # Arrange
+    source = GraphExecution(workflow_id="wf-2")
+    source.start()
+    source_node = source.get_or_create_node_execution("node-source")
+    source_node.mark_taken()
+    serialized = source.dumps()
+
+    target = GraphExecution(workflow_id="wf-2")
+    target.start()
+    target.abort("pre-existing abort")
+    temp_node = target.get_or_create_node_execution("node-temp")
+    temp_node.increment_retry()
+    temp_node.mark_failed("temp error")
+
+    # Act
+    target.loads(serialized)
+
+    # Assert
+    assert target.aborted is False
+    assert target.error is None
+    assert target.started is True
+    assert target.completed is False
+    assert set(target.node_executions) == {"node-source"}
+    restored_node = target.node_executions["node-source"]
+    assert restored_node.state is NodeState.TAKEN
+    assert restored_node.retry_count == 0
+    assert restored_node.execution_id is None
+    assert restored_node.error is None
+
+
+def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None:
+    """GraphEngine restores GraphExecution state from runtime snapshot on init."""
+
+    # Arrange serialized execution state
+    execution = GraphExecution(workflow_id="wf-init")
+    execution.start()
+    node_state = execution.get_or_create_node_execution("serialized-node")
+    node_state.mark_taken()
+    execution.complete()
+    serialized = execution.dumps()
+
+    runtime_state = GraphRuntimeState(
+        variable_pool=MagicMock(),
+        start_at=0.0,
+        graph_execution_json=serialized,
+    )
+
+    class DummyNode:
+        def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
+            self.graph_runtime_state = graph_runtime_state
+            self.execution_type = NodeExecutionType.EXECUTABLE
+            self.id = "dummy-node"
+            self.state = NodeState.UNKNOWN
+            self.title = "dummy"
+
+    class DummyGraph:
+        def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
+            self.nodes = {"dummy-node": DummyNode(graph_runtime_state)}
+            self.edges: dict[str, object] = {}
+            self.root_node = self.nodes["dummy-node"]
+
+        def get_incoming_edges(self, node_id: str):  # pragma: no cover - not exercised
+            return []
+
+        def get_outgoing_edges(self, node_id: str):  # pragma: no cover - not exercised
+            return []
+
+    dummy_graph = DummyGraph(runtime_state)
+
+    def _stub(*_args, **_kwargs):
+        return MagicMock()
+
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub)
+    command_processor = MagicMock()
+    command_processor.register_handler = MagicMock()
+    monkeypatch.setattr(
+        "core.workflow.graph_engine.graph_engine.CommandProcessor",
+        lambda *_args, **_kwargs: command_processor,
+    )
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub)
+
+    # Act
+    engine = GraphEngine(
+        workflow_id="wf-init",
+        graph=dummy_graph,  # type: ignore[arg-type]
+        graph_runtime_state=runtime_state,
+        command_channel=MagicMock(),
+    )
+
+    # Assert
+    assert engine._graph_execution.started is True
+    assert engine._graph_execution.completed is True
+    assert set(engine._graph_execution.node_executions) == {"serialized-node"}
+    restored_node = engine._graph_execution.node_executions["serialized-node"]
+    assert restored_node.state is NodeState.TAKEN
+    assert restored_node.retry_count == 0
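
Usage sketch (not part of the diff): the flow this patch enables is GraphExecution.dumps() -> GraphRuntimeState.graph_execution_json -> GraphEngine.__init__ calling loads(), as the third test above exercises. The snippet below is a minimal sketch of a checkpoint/resume round trip; it assumes variable_pool, graph, workflow_id, and command_channel are already in scope, that time.perf_counter() is an acceptable start_at, and it reaches into the private _graph_execution attribute the same way the test does.

    import time

    # Checkpoint: serialize the aggregate held by a running engine.
    snapshot = engine._graph_execution.dumps()

    # Resume: a fresh runtime state carries the snapshot; GraphEngine.__init__
    # restores it via GraphExecution.loads() whenever graph_execution_json is non-empty.
    resumed_runtime_state = GraphRuntimeState(
        variable_pool=variable_pool,
        start_at=time.perf_counter(),
        graph_execution_json=snapshot,
    )
    resumed_engine = GraphEngine(
        workflow_id=workflow_id,
        graph=graph,
        graph_runtime_state=resumed_runtime_state,
        command_channel=command_channel,
    )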