From 73a77563509d2ab279c50ce7e031596e47c4d4f6 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 17 Sep 2025 12:45:51 +0800 Subject: [PATCH] feat(graph_engine): allow to dumps and loads RSC --- .../workflow/entities/graph_runtime_state.py | 11 +- .../workflow/graph_engine/graph_engine.py | 2 + .../response_coordinator/coordinator.py | 138 +++++++++++++++- .../test_graph_execution_serialization.py | 151 +++++++++++------- 4 files changed, 236 insertions(+), 66 deletions(-) diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index c8fb1de20e..6362f291ea 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -17,6 +17,7 @@ class GraphRuntimeState(BaseModel): _node_run_steps: int = PrivateAttr(default=0) _ready_queue_json: str = PrivateAttr() _graph_execution_json: str = PrivateAttr() + _response_coordinator_json: str = PrivateAttr() def __init__( self, @@ -29,6 +30,7 @@ class GraphRuntimeState(BaseModel): node_run_steps: int = 0, ready_queue_json: str = "", graph_execution_json: str = "", + response_coordinator_json: str = "", **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" @@ -57,6 +59,7 @@ class GraphRuntimeState(BaseModel): self._ready_queue_json = ready_queue_json self._graph_execution_json = graph_execution_json + self._response_coordinator_json = response_coordinator_json @property def variable_pool(self) -> VariablePool: @@ -151,7 +154,7 @@ class GraphRuntimeState(BaseModel): """Get a copy of the serialized graph execution state.""" return self._graph_execution_json - @graph_execution_json.setter - def graph_execution_json(self, value: str) -> None: - """Set the serialized graph execution state.""" - self._graph_execution_json = value + @property + def response_coordinator_json(self) -> str: + """Get a copy of the serialized response coordinator state.""" + return self._response_coordinator_json diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 1a136d4365..164ae41cca 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -105,6 +105,8 @@ class GraphEngine: self._response_coordinator = ResponseStreamCoordinator( variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph ) + if graph_runtime_state.response_coordinator_json != "": + self._response_coordinator.loads(graph_runtime_state.response_coordinator_json) # === Event Management === # Event manager handles both collection and emission of events diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index b5224cbc22..985992f3f1 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -9,9 +9,11 @@ import logging from collections import deque from collections.abc import Sequence from threading import RLock -from typing import TypeAlias, final +from typing import Literal, TypeAlias, final from uuid import uuid4 +from pydantic import BaseModel, Field + from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import NodeExecutionType, NodeState from core.workflow.graph import Graph @@ -28,6 +30,43 @@ NodeID: TypeAlias = str EdgeID: TypeAlias = str +class ResponseSessionState(BaseModel): + """Serializable representation of a response session.""" + + node_id: str + index: int = Field(default=0, ge=0) + + +class StreamBufferState(BaseModel): + """Serializable representation of buffered stream chunks.""" + + selector: tuple[str, ...] + events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) + + +class StreamPositionState(BaseModel): + """Serializable representation for stream read positions.""" + + selector: tuple[str, ...] + position: int = Field(default=0, ge=0) + + +class ResponseStreamCoordinatorState(BaseModel): + """Serialized snapshot of ResponseStreamCoordinator.""" + + type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") + version: str = Field(default="1.0") + response_nodes: Sequence[str] = Field(default_factory=list) + active_session: ResponseSessionState | None = None + waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) + pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) + node_execution_ids: dict[str, str] = Field(default_factory=dict) + paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) + stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) + stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) + closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) + + @final class ResponseStreamCoordinator: """ @@ -69,6 +108,8 @@ class ResponseStreamCoordinator: def register(self, response_node_id: NodeID) -> None: with self._lock: + if response_node_id in self._response_nodes: + return self._response_nodes.add(response_node_id) # Build and save paths map for this response node @@ -558,3 +599,98 @@ class ResponseStreamCoordinator: """ key = tuple(selector) return key in self._closed_streams + + def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: + """Convert an in-memory session into its serializable form.""" + + if session is None: + return None + return ResponseSessionState(node_id=session.node_id, index=session.index) + + def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: + """Rebuild a response session from serialized data.""" + + node = self._graph.nodes.get(session_state.node_id) + if node is None: + raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") + + session = ResponseSession.from_node(node) + session.index = session_state.index + return session + + def dumps(self) -> str: + """Serialize coordinator state to JSON.""" + + with self._lock: + state = ResponseStreamCoordinatorState( + response_nodes=sorted(self._response_nodes), + active_session=self._serialize_session(self._active_session), + waiting_sessions=[ + session_state + for session in list(self._waiting_sessions) + if (session_state := self._serialize_session(session)) is not None + ], + pending_sessions=[ + session_state + for _, session in sorted(self._response_sessions.items()) + if (session_state := self._serialize_session(session)) is not None + ], + node_execution_ids=dict(sorted(self._node_execution_ids.items())), + paths_map={ + node_id: [path.edges.copy() for path in paths] + for node_id, paths in sorted(self._paths_maps.items()) + }, + stream_buffers=[ + StreamBufferState( + selector=selector, + events=[event.model_copy(deep=True) for event in events], + ) + for selector, events in sorted(self._stream_buffers.items()) + ], + stream_positions=[ + StreamPositionState(selector=selector, position=position) + for selector, position in sorted(self._stream_positions.items()) + ], + closed_streams=sorted(self._closed_streams), + ) + return state.model_dump_json() + + def loads(self, data: str) -> None: + """Restore coordinator state from JSON.""" + + state = ResponseStreamCoordinatorState.model_validate_json(data) + + if state.type != "ResponseStreamCoordinator": + raise ValueError(f"Invalid serialized data type: {state.type}") + + if state.version != "1.0": + raise ValueError(f"Unsupported serialized version: {state.version}") + + with self._lock: + self._response_nodes = set(state.response_nodes) + self._paths_maps = { + node_id: [Path(edges=list(path_edges)) for path_edges in paths] + for node_id, paths in state.paths_map.items() + } + self._node_execution_ids = dict(state.node_execution_ids) + + self._stream_buffers = { + tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] + for buffer in state.stream_buffers + } + self._stream_positions = { + tuple(position.selector): position.position for position in state.stream_positions + } + for selector in self._stream_buffers: + self._stream_positions.setdefault(selector, 0) + + self._closed_streams = {tuple(selector) for selector in state.closed_streams} + + self._waiting_sessions = deque( + self._session_from_state(session_state) for session_state in state.waiting_sessions + ) + self._response_sessions = { + session_state.node_id: self._session_from_state(session_state) + for session_state in state.pending_sessions + } + self._active_session = self._session_from_state(state.active_session) if state.active_session else None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 2388e4d57b..6385b0b91f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -3,12 +3,16 @@ from __future__ import annotations import json +from collections import deque from unittest.mock import MagicMock -from core.workflow.entities import GraphRuntimeState -from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph_engine import GraphEngine +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.graph_engine.domain import GraphExecution +from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator +from core.workflow.graph_engine.response_coordinator.path import Path +from core.workflow.graph_engine.response_coordinator.session import ResponseSession +from core.workflow.graph_events import NodeRunStreamChunkEvent +from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment class CustomGraphExecutionError(Exception): @@ -88,78 +92,103 @@ def test_graph_execution_loads_replaces_existing_state() -> None: assert restored_node.error is None -def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None: - """GraphEngine restores GraphExecution state from runtime snapshot on init.""" +def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: + """ResponseStreamCoordinator serialization restores coordinator internals.""" - # Arrange serialized execution state - execution = GraphExecution(workflow_id="wf-init") - execution.start() - node_state = execution.get_or_create_node_execution("serialized-node") - node_state.mark_taken() - execution.complete() - serialized = execution.dumps() - - runtime_state = GraphRuntimeState( - variable_pool=MagicMock(), - start_at=0.0, - graph_execution_json=serialized, - ) + template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) + template_secondary = Template(segments=[TextSegment(text="secondary")]) class DummyNode: - def __init__(self, graph_runtime_state: GraphRuntimeState) -> None: - self.graph_runtime_state = graph_runtime_state - self.execution_type = NodeExecutionType.EXECUTABLE - self.id = "dummy-node" + def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: + self.id = node_id + self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM + self.execution_type = execution_type self.state = NodeState.UNKNOWN - self.title = "dummy" + self.title = node_id + self.template = template + + def blocks_variable_output(self, *_args) -> bool: + return False + + response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE) + response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) + response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) + source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) class DummyGraph: - def __init__(self, graph_runtime_state: GraphRuntimeState) -> None: - self.nodes = {"dummy-node": DummyNode(graph_runtime_state)} + def __init__(self) -> None: + self.nodes = { + response_node1.id: response_node1, + response_node2.id: response_node2, + response_node3.id: response_node3, + source_node.id: source_node, + } self.edges: dict[str, object] = {} - self.root_node = self.nodes["dummy-node"] + self.root_node = response_node1 - def get_incoming_edges(self, node_id: str): # pragma: no cover - not exercised + def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised return [] - def get_outgoing_edges(self, node_id: str): # pragma: no cover - not exercised + def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised return [] - dummy_graph = DummyGraph(runtime_state) + graph = DummyGraph() - def _stub(*_args, **_kwargs): - return MagicMock() + def fake_from_node(cls, node: DummyNode) -> ResponseSession: + return ResponseSession(node_id=node.id, template=node.template) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub) - command_processor = MagicMock() - command_processor.register_handler = MagicMock() - monkeypatch.setattr( - "core.workflow.graph_engine.graph_engine.CommandProcessor", - lambda *_args, **_kwargs: command_processor, + monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) + + coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] + coordinator._response_nodes = {"response-1", "response-2", "response-3"} + coordinator._paths_maps = { + "response-1": [Path(edges=["edge-1"])], + "response-2": [Path(edges=[])], + "response-3": [Path(edges=["edge-2", "edge-3"])], + } + + active_session = ResponseSession(node_id="response-1", template=response_node1.template) + active_session.index = 1 + coordinator._active_session = active_session + waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) + coordinator._waiting_sessions = deque([waiting_session]) + pending_session = ResponseSession(node_id="response-3", template=response_node3.template) + pending_session.index = 2 + coordinator._response_sessions = {"response-3": pending_session} + + coordinator._node_execution_ids = {"response-1": "exec-1"} + event = NodeRunStreamChunkEvent( + id="exec-1", + node_id="response-1", + node_type=NodeType.ANSWER, + selector=["node-source", "text"], + chunk="chunk-1", + is_final=False, ) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub) + coordinator._stream_buffers = {("node-source", "text"): [event]} + coordinator._stream_positions = {("node-source", "text"): 1} + coordinator._closed_streams = {("node-source", "text")} - # Act - engine = GraphEngine( - workflow_id="wf-init", - graph=dummy_graph, # type: ignore[arg-type] - graph_runtime_state=runtime_state, - command_channel=MagicMock(), - ) + serialized = coordinator.dumps() - # Assert - assert engine._graph_execution.started is True - assert engine._graph_execution.completed is True - assert set(engine._graph_execution.node_executions) == {"serialized-node"} - restored_node = engine._graph_execution.node_executions["serialized-node"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 + restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] + monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) + restored.loads(serialized) + + assert restored._response_nodes == {"response-1", "response-2", "response-3"} + assert restored._paths_maps["response-1"][0].edges == ["edge-1"] + assert restored._active_session is not None + assert restored._active_session.node_id == "response-1" + assert restored._active_session.index == 1 + waiting_restored = list(restored._waiting_sessions) + assert len(waiting_restored) == 1 + assert waiting_restored[0].node_id == "response-2" + assert waiting_restored[0].index == 0 + assert set(restored._response_sessions) == {"response-3"} + assert restored._response_sessions["response-3"].index == 2 + assert restored._node_execution_ids == {"response-1": "exec-1"} + assert ("node-source", "text") in restored._stream_buffers + restored_event = restored._stream_buffers[("node-source", "text")][0] + assert restored_event.chunk == "chunk-1" + assert restored._stream_positions[("node-source", "text")] == 1 + assert ("node-source", "text") in restored._closed_streams