feat(graph_engine): support dumps and loads in GraphExecution

This commit is contained in:
-LAN- 2025-09-16 19:38:10 +08:00
parent 976b3b5e83
commit 02d15ebd5a
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
4 changed files with 319 additions and 3 deletions

View File

@ -16,6 +16,7 @@ class GraphRuntimeState(BaseModel):
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue_json: str = PrivateAttr()
_graph_execution_json: str = PrivateAttr()
def __init__(
self,
@ -27,6 +28,7 @@ class GraphRuntimeState(BaseModel):
outputs: dict[str, object] | None = None,
node_run_steps: int = 0,
ready_queue_json: str = "",
graph_execution_json: str = "",
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
@ -54,6 +56,7 @@ class GraphRuntimeState(BaseModel):
self._node_run_steps = node_run_steps
self._ready_queue_json = ready_queue_json
self._graph_execution_json = graph_execution_json
@property
def variable_pool(self) -> VariablePool:
@ -142,3 +145,13 @@ class GraphRuntimeState(BaseModel):
def ready_queue_json(self) -> str:
"""Get a copy of the ready queue state."""
return self._ready_queue_json
@property
def graph_execution_json(self) -> str:
"""Get a copy of the serialized graph execution state."""
return self._graph_execution_json
@graph_execution_json.setter
def graph_execution_json(self, value: str) -> None:
"""Set the serialized graph execution state."""
self._graph_execution_json = value

View File

@ -1,12 +1,94 @@
"""
GraphExecution aggregate root managing the overall graph execution state.
"""
"""GraphExecution aggregate root managing the overall graph execution state."""
from __future__ import annotations
from dataclasses import dataclass, field
from importlib import import_module
from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.enums import NodeState
from .node_execution import NodeExecution
class GraphExecutionErrorState(BaseModel):
"""Serializable representation of an execution error."""
module: str = Field(description="Module containing the exception class")
qualname: str = Field(description="Qualified name of the exception class")
message: str | None = Field(default=None, description="Exception message string")
class NodeExecutionState(BaseModel):
"""Serializable representation of a node execution entity."""
node_id: str
state: NodeState = Field(default=NodeState.UNKNOWN)
retry_count: int = Field(default=0)
execution_id: str | None = Field(default=None)
error: str | None = Field(default=None)
class GraphExecutionState(BaseModel):
"""Pydantic model describing serialized GraphExecution state."""
type: Literal["GraphExecution"] = Field(default="GraphExecution")
version: str = Field(default="1.0")
workflow_id: str
started: bool = Field(default=False)
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list)
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
"""Convert an exception into its serializable representation."""
if error is None:
return None
return GraphExecutionErrorState(
module=error.__class__.__module__,
qualname=error.__class__.__qualname__,
message=str(error),
)
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
"""Locate an exception class from its module and qualified name."""
module = import_module(module_name)
attr: object = module
for part in qualname.split("."):
attr = getattr(attr, part)
if isinstance(attr, type) and issubclass(attr, Exception):
return attr
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
"""Reconstruct an exception instance from serialized data."""
if state is None:
return None
try:
exception_class = _resolve_exception_class(state.module, state.qualname)
if state.message is None:
return exception_class()
return exception_class(state.message)
except Exception:
# Fallback to RuntimeError when reconstruction fails
if state.message is None:
return RuntimeError(state.qualname)
return RuntimeError(state.message)
@dataclass
class GraphExecution:
"""
@ -69,3 +151,57 @@ class GraphExecution:
if not self.error:
return None
return str(self.error)
def dumps(self) -> str:
"""Serialize the aggregate state into a JSON string."""
node_states = [
NodeExecutionState(
node_id=node_id,
state=node_execution.state,
retry_count=node_execution.retry_count,
execution_id=node_execution.execution_id,
error=node_execution.error,
)
for node_id, node_execution in sorted(self.node_executions.items())
]
state = GraphExecutionState(
workflow_id=self.workflow_id,
started=self.started,
completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
node_executions=node_states,
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""Restore aggregate state from a serialized JSON string."""
state = GraphExecutionState.model_validate_json(data)
if state.type != "GraphExecution":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported serialized version: {state.version}")
if self.workflow_id != state.workflow_id:
raise ValueError("Serialized workflow_id does not match aggregate identity")
self.started = state.started
self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
state=item.state,
retry_count=item.retry_count,
execution_id=item.execution_id,
error=item.error,
)
for item in state.node_executions
}

View File

@ -68,6 +68,8 @@ class GraphEngine:
# Graph execution tracks the overall execution state
self._graph_execution = GraphExecution(workflow_id=workflow_id)
if graph_runtime_state.graph_execution_json != "":
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
# === Core Dependencies ===
# Graph structure and configuration

View File

@ -0,0 +1,165 @@
"""Unit tests for GraphExecution serialization helpers."""
from __future__ import annotations
import json
from unittest.mock import MagicMock
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.domain import GraphExecution
class CustomGraphExecutionError(Exception):
"""Custom exception used to verify error serialization."""
def test_graph_execution_serialization_round_trip() -> None:
"""GraphExecution serialization restores full aggregate state."""
# Arrange
execution = GraphExecution(workflow_id="wf-1")
execution.start()
node_a = execution.get_or_create_node_execution("node-a")
node_a.mark_started(execution_id="exec-1")
node_a.increment_retry()
node_a.mark_failed("boom")
node_b = execution.get_or_create_node_execution("node-b")
node_b.mark_skipped()
execution.fail(CustomGraphExecutionError("serialization failure"))
# Act
serialized = execution.dumps()
payload = json.loads(serialized)
restored = GraphExecution(workflow_id="wf-1")
restored.loads(serialized)
# Assert
assert payload["type"] == "GraphExecution"
assert payload["version"] == "1.0"
assert restored.workflow_id == "wf-1"
assert restored.started is True
assert restored.completed is True
assert restored.aborted is False
assert isinstance(restored.error, CustomGraphExecutionError)
assert str(restored.error) == "serialization failure"
assert set(restored.node_executions) == {"node-a", "node-b"}
restored_node_a = restored.node_executions["node-a"]
assert restored_node_a.state is NodeState.TAKEN
assert restored_node_a.retry_count == 1
assert restored_node_a.execution_id == "exec-1"
assert restored_node_a.error == "boom"
restored_node_b = restored.node_executions["node-b"]
assert restored_node_b.state is NodeState.SKIPPED
assert restored_node_b.retry_count == 0
assert restored_node_b.execution_id is None
assert restored_node_b.error is None
def test_graph_execution_loads_replaces_existing_state() -> None:
"""loads replaces existing runtime data with serialized snapshot."""
# Arrange
source = GraphExecution(workflow_id="wf-2")
source.start()
source_node = source.get_or_create_node_execution("node-source")
source_node.mark_taken()
serialized = source.dumps()
target = GraphExecution(workflow_id="wf-2")
target.start()
target.abort("pre-existing abort")
temp_node = target.get_or_create_node_execution("node-temp")
temp_node.increment_retry()
temp_node.mark_failed("temp error")
# Act
target.loads(serialized)
# Assert
assert target.aborted is False
assert target.error is None
assert target.started is True
assert target.completed is False
assert set(target.node_executions) == {"node-source"}
restored_node = target.node_executions["node-source"]
assert restored_node.state is NodeState.TAKEN
assert restored_node.retry_count == 0
assert restored_node.execution_id is None
assert restored_node.error is None
def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None:
"""GraphEngine restores GraphExecution state from runtime snapshot on init."""
# Arrange serialized execution state
execution = GraphExecution(workflow_id="wf-init")
execution.start()
node_state = execution.get_or_create_node_execution("serialized-node")
node_state.mark_taken()
execution.complete()
serialized = execution.dumps()
runtime_state = GraphRuntimeState(
variable_pool=MagicMock(),
start_at=0.0,
graph_execution_json=serialized,
)
class DummyNode:
def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
self.graph_runtime_state = graph_runtime_state
self.execution_type = NodeExecutionType.EXECUTABLE
self.id = "dummy-node"
self.state = NodeState.UNKNOWN
self.title = "dummy"
class DummyGraph:
def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
self.nodes = {"dummy-node": DummyNode(graph_runtime_state)}
self.edges: dict[str, object] = {}
self.root_node = self.nodes["dummy-node"]
def get_incoming_edges(self, node_id: str): # pragma: no cover - not exercised
return []
def get_outgoing_edges(self, node_id: str): # pragma: no cover - not exercised
return []
dummy_graph = DummyGraph(runtime_state)
def _stub(*_args, **_kwargs):
return MagicMock()
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub)
command_processor = MagicMock()
command_processor.register_handler = MagicMock()
monkeypatch.setattr(
"core.workflow.graph_engine.graph_engine.CommandProcessor",
lambda *_args, **_kwargs: command_processor,
)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub)
# Act
engine = GraphEngine(
workflow_id="wf-init",
graph=dummy_graph, # type: ignore[arg-type]
graph_runtime_state=runtime_state,
command_channel=MagicMock(),
)
# Assert
assert engine._graph_execution.started is True
assert engine._graph_execution.completed is True
assert set(engine._graph_execution.node_executions) == {"serialized-node"}
restored_node = engine._graph_execution.node_executions["serialized-node"]
assert restored_node.state is NodeState.TAKEN
assert restored_node.retry_count == 0