feat(graph_engine): support dumping and loading the ResponseStreamCoordinator (RSC)

This commit is contained in:
-LAN- 2025-09-17 12:45:51 +08:00
parent 02d15ebd5a
commit 73a7756350
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
4 changed files with 236 additions and 66 deletions

View File

@ -17,6 +17,7 @@ class GraphRuntimeState(BaseModel):
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue_json: str = PrivateAttr()
_graph_execution_json: str = PrivateAttr()
_response_coordinator_json: str = PrivateAttr()
def __init__(
self,
@ -29,6 +30,7 @@ class GraphRuntimeState(BaseModel):
node_run_steps: int = 0,
ready_queue_json: str = "",
graph_execution_json: str = "",
response_coordinator_json: str = "",
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
@ -57,6 +59,7 @@ class GraphRuntimeState(BaseModel):
self._ready_queue_json = ready_queue_json
self._graph_execution_json = graph_execution_json
self._response_coordinator_json = response_coordinator_json
@property
def variable_pool(self) -> VariablePool:
@ -151,7 +154,7 @@ class GraphRuntimeState(BaseModel):
"""Get a copy of the serialized graph execution state."""
return self._graph_execution_json
@graph_execution_json.setter
def graph_execution_json(self, value: str) -> None:
    """Set the serialized graph execution state."""
    # Stored verbatim as a raw JSON string; no validation is performed here.
    self._graph_execution_json = value
@property
def response_coordinator_json(self) -> str:
    """Get a copy of the serialized response coordinator state."""
    return self._response_coordinator_json

@response_coordinator_json.setter
def response_coordinator_json(self, value: str) -> None:
    """Set the serialized response coordinator state.

    Added for symmetry with ``graph_execution_json``: without a setter the
    private attribute can only be populated via ``__init__``, so a snapshot
    taken after construction could not be stored back on the runtime state.
    The value is stored verbatim; no validation is performed here.
    """
    self._response_coordinator_json = value

View File

@ -105,6 +105,8 @@ class GraphEngine:
self._response_coordinator = ResponseStreamCoordinator(
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
)
if graph_runtime_state.response_coordinator_json != "":
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
# === Event Management ===
# Event manager handles both collection and emission of events

View File

@ -9,9 +9,11 @@ import logging
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import TypeAlias, final
from typing import Literal, TypeAlias, final
from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
@ -28,6 +30,43 @@ NodeID: TypeAlias = str
EdgeID: TypeAlias = str
class ResponseSessionState(BaseModel):
    """Serializable representation of a response session."""

    # ID of the response node the session belongs to; resolved back to a
    # graph node on load (see _session_from_state).
    node_id: str
    # Non-negative progress index; restored onto the rebuilt session on load.
    index: int = Field(default=0, ge=0)
class StreamBufferState(BaseModel):
    """Serializable representation of buffered stream chunks."""

    # Variable selector identifying the stream, e.g. ("node-id", "text").
    selector: tuple[str, ...]
    # Buffered chunk events for this selector; deep-copied on dump and load.
    events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
class StreamPositionState(BaseModel):
    """Serializable representation for stream read positions."""

    # Variable selector identifying the stream this position belongs to.
    selector: tuple[str, ...]
    # Non-negative read offset into the corresponding stream buffer.
    position: int = Field(default=0, ge=0)
class ResponseStreamCoordinatorState(BaseModel):
    """Serialized snapshot of ResponseStreamCoordinator."""

    # Discriminator and schema version; both are checked by loads(), which
    # rejects anything other than this type and version "1.0".
    type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
    version: str = Field(default="1.0")
    # Registered response node IDs (sorted on dump for determinism).
    response_nodes: Sequence[str] = Field(default_factory=list)
    # The session currently streaming, if any.
    active_session: ResponseSessionState | None = None
    # Sessions queued behind the active one; restored into a deque on load.
    waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
    # Sessions not yet started; re-keyed by node_id on load.
    pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
    # Mapping of node ID -> node execution ID.
    node_execution_ids: dict[str, str] = Field(default_factory=dict)
    # Mapping of response node ID -> list of edge-ID paths leading to it.
    paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
    # Per-selector buffered chunks and read positions.
    stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
    stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
    # Selectors whose streams have been closed.
    closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
@final
class ResponseStreamCoordinator:
"""
@ -69,6 +108,8 @@ class ResponseStreamCoordinator:
def register(self, response_node_id: NodeID) -> None:
with self._lock:
if response_node_id in self._response_nodes:
return
self._response_nodes.add(response_node_id)
# Build and save paths map for this response node
@ -558,3 +599,98 @@ class ResponseStreamCoordinator:
"""
key = tuple(selector)
return key in self._closed_streams
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
    """Snapshot *session* into its serializable form, or None if there is no session."""
    return (
        None
        if session is None
        else ResponseSessionState(node_id=session.node_id, index=session.index)
    )
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
    """Recreate a live response session from its serialized counterpart.

    Raises:
        ValueError: if the serialized node ID is not present in the graph.
    """
    if (target_node := self._graph.nodes.get(session_state.node_id)) is None:
        raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
    restored = ResponseSession.from_node(target_node)
    restored.index = session_state.index
    return restored
def dumps(self) -> str:
    """Serialize coordinator state to JSON.

    Collections are sorted before serialization so the output is
    deterministic for a given coordinator state.
    """
    with self._lock:
        # Waiting sessions keep queue order; drop any that cannot be serialized.
        waiting_states: list[ResponseSessionState] = []
        for queued in list(self._waiting_sessions):
            serialized = self._serialize_session(queued)
            if serialized is not None:
                waiting_states.append(serialized)

        # Pending sessions are emitted in node-ID order.
        pending_states: list[ResponseSessionState] = []
        for _, pending in sorted(self._response_sessions.items()):
            serialized = self._serialize_session(pending)
            if serialized is not None:
                pending_states.append(serialized)

        # Deep-copy buffered events so later mutation cannot affect the snapshot.
        buffer_states = [
            StreamBufferState(
                selector=selector,
                events=[chunk.model_copy(deep=True) for chunk in chunks],
            )
            for selector, chunks in sorted(self._stream_buffers.items())
        ]
        position_states = [
            StreamPositionState(selector=selector, position=offset)
            for selector, offset in sorted(self._stream_positions.items())
        ]

        snapshot = ResponseStreamCoordinatorState(
            response_nodes=sorted(self._response_nodes),
            active_session=self._serialize_session(self._active_session),
            waiting_sessions=waiting_states,
            pending_sessions=pending_states,
            node_execution_ids=dict(sorted(self._node_execution_ids.items())),
            paths_map={
                node_id: [list(path.edges) for path in paths]
                for node_id, paths in sorted(self._paths_maps.items())
            },
            stream_buffers=buffer_states,
            stream_positions=position_states,
            closed_streams=sorted(self._closed_streams),
        )
        return snapshot.model_dump_json()
def loads(self, data: str) -> None:
    """Restore coordinator state from a JSON snapshot produced by dumps().

    Raises:
        ValueError: if the payload's type discriminator or version does not
            match what this coordinator produces, or if a serialized session
            references a node missing from the graph.
    """
    snapshot = ResponseStreamCoordinatorState.model_validate_json(data)
    # Reject snapshots that were not produced by this coordinator type/version.
    if snapshot.type != "ResponseStreamCoordinator":
        raise ValueError(f"Invalid serialized data type: {snapshot.type}")
    if snapshot.version != "1.0":
        raise ValueError(f"Unsupported serialized version: {snapshot.version}")
    with self._lock:
        self._response_nodes = set(snapshot.response_nodes)

        rebuilt_paths: dict[str, list[Path]] = {}
        for node_id, edge_lists in snapshot.paths_map.items():
            rebuilt_paths[node_id] = [Path(edges=list(edges)) for edges in edge_lists]
        self._paths_maps = rebuilt_paths

        self._node_execution_ids = dict(snapshot.node_execution_ids)

        # Deep-copy events so the restored buffers are independent of the snapshot.
        rebuilt_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
        for buffer_state in snapshot.stream_buffers:
            rebuilt_buffers[tuple(buffer_state.selector)] = [
                chunk.model_copy(deep=True) for chunk in buffer_state.events
            ]
        self._stream_buffers = rebuilt_buffers

        self._stream_positions = {
            tuple(entry.selector): entry.position for entry in snapshot.stream_positions
        }
        # Any buffered selector without an explicit position starts reading at 0.
        for buffered_selector in rebuilt_buffers:
            self._stream_positions.setdefault(buffered_selector, 0)

        self._closed_streams = {tuple(sel) for sel in snapshot.closed_streams}

        self._waiting_sessions = deque(
            self._session_from_state(entry) for entry in snapshot.waiting_sessions
        )
        self._response_sessions = {
            entry.node_id: self._session_from_state(entry)
            for entry in snapshot.pending_sessions
        }
        if snapshot.active_session is None:
            self._active_session = None
        else:
            self._active_session = self._session_from_state(snapshot.active_session)

View File

@ -3,12 +3,16 @@
from __future__ import annotations
import json
from collections import deque
from unittest.mock import MagicMock
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph_engine import GraphEngine
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.graph_engine.domain import GraphExecution
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
from core.workflow.graph_engine.response_coordinator.path import Path
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
from core.workflow.graph_events import NodeRunStreamChunkEvent
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
class CustomGraphExecutionError(Exception):
@ -88,78 +92,103 @@ def test_graph_execution_loads_replaces_existing_state() -> None:
assert restored_node.error is None
def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None:
"""GraphEngine restores GraphExecution state from runtime snapshot on init."""
def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None:
"""ResponseStreamCoordinator serialization restores coordinator internals."""
# Arrange serialized execution state
execution = GraphExecution(workflow_id="wf-init")
execution.start()
node_state = execution.get_or_create_node_execution("serialized-node")
node_state.mark_taken()
execution.complete()
serialized = execution.dumps()
runtime_state = GraphRuntimeState(
variable_pool=MagicMock(),
start_at=0.0,
graph_execution_json=serialized,
)
template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])])
template_secondary = Template(segments=[TextSegment(text="secondary")])
class DummyNode:
def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
self.graph_runtime_state = graph_runtime_state
self.execution_type = NodeExecutionType.EXECUTABLE
self.id = "dummy-node"
def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None:
self.id = node_id
self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM
self.execution_type = execution_type
self.state = NodeState.UNKNOWN
self.title = "dummy"
self.title = node_id
self.template = template
def blocks_variable_output(self, *_args) -> bool:
return False
response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE)
response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE)
response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE)
source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE)
class DummyGraph:
def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
self.nodes = {"dummy-node": DummyNode(graph_runtime_state)}
def __init__(self) -> None:
self.nodes = {
response_node1.id: response_node1,
response_node2.id: response_node2,
response_node3.id: response_node3,
source_node.id: source_node,
}
self.edges: dict[str, object] = {}
self.root_node = self.nodes["dummy-node"]
self.root_node = response_node1
def get_incoming_edges(self, node_id: str): # pragma: no cover - not exercised
def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised
return []
def get_outgoing_edges(self, node_id: str): # pragma: no cover - not exercised
def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised
return []
dummy_graph = DummyGraph(runtime_state)
graph = DummyGraph()
def _stub(*_args, **_kwargs):
return MagicMock()
def fake_from_node(cls, node: DummyNode) -> ResponseSession:
return ResponseSession(node_id=node.id, template=node.template)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub)
command_processor = MagicMock()
command_processor.register_handler = MagicMock()
monkeypatch.setattr(
"core.workflow.graph_engine.graph_engine.CommandProcessor",
lambda *_args, **_kwargs: command_processor,
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
coordinator._response_nodes = {"response-1", "response-2", "response-3"}
coordinator._paths_maps = {
"response-1": [Path(edges=["edge-1"])],
"response-2": [Path(edges=[])],
"response-3": [Path(edges=["edge-2", "edge-3"])],
}
active_session = ResponseSession(node_id="response-1", template=response_node1.template)
active_session.index = 1
coordinator._active_session = active_session
waiting_session = ResponseSession(node_id="response-2", template=response_node2.template)
coordinator._waiting_sessions = deque([waiting_session])
pending_session = ResponseSession(node_id="response-3", template=response_node3.template)
pending_session.index = 2
coordinator._response_sessions = {"response-3": pending_session}
coordinator._node_execution_ids = {"response-1": "exec-1"}
event = NodeRunStreamChunkEvent(
id="exec-1",
node_id="response-1",
node_type=NodeType.ANSWER,
selector=["node-source", "text"],
chunk="chunk-1",
is_final=False,
)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub)
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub)
coordinator._stream_buffers = {("node-source", "text"): [event]}
coordinator._stream_positions = {("node-source", "text"): 1}
coordinator._closed_streams = {("node-source", "text")}
# Act
engine = GraphEngine(
workflow_id="wf-init",
graph=dummy_graph, # type: ignore[arg-type]
graph_runtime_state=runtime_state,
command_channel=MagicMock(),
)
serialized = coordinator.dumps()
# Assert
assert engine._graph_execution.started is True
assert engine._graph_execution.completed is True
assert set(engine._graph_execution.node_executions) == {"serialized-node"}
restored_node = engine._graph_execution.node_executions["serialized-node"]
assert restored_node.state is NodeState.TAKEN
assert restored_node.retry_count == 0
restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
restored.loads(serialized)
assert restored._response_nodes == {"response-1", "response-2", "response-3"}
assert restored._paths_maps["response-1"][0].edges == ["edge-1"]
assert restored._active_session is not None
assert restored._active_session.node_id == "response-1"
assert restored._active_session.index == 1
waiting_restored = list(restored._waiting_sessions)
assert len(waiting_restored) == 1
assert waiting_restored[0].node_id == "response-2"
assert waiting_restored[0].index == 0
assert set(restored._response_sessions) == {"response-3"}
assert restored._response_sessions["response-3"].index == 2
assert restored._node_execution_ids == {"response-1": "exec-1"}
assert ("node-source", "text") in restored._stream_buffers
restored_event = restored._stream_buffers[("node-source", "text")][0]
assert restored_event.chunk == "chunk-1"
assert restored._stream_positions[("node-source", "text")] == 1
assert ("node-source", "text") in restored._closed_streams