diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 486718dc62..4c322c6aa6 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -5,6 +5,7 @@ import json from collections.abc import Mapping, Sequence from collections.abc import Mapping as TypingMapping from copy import deepcopy +from dataclasses import dataclass from typing import Any, Protocol from pydantic.json import pydantic_encoder @@ -106,6 +107,23 @@ class GraphProtocol(Protocol): def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... +@dataclass(slots=True) +class _GraphRuntimeStateSnapshot: + """Immutable view of a serialized runtime state snapshot.""" + + start_at: float + total_tokens: int + node_run_steps: int + llm_usage: LLMUsage + outputs: dict[str, Any] + variable_pool: VariablePool + has_variable_pool: bool + ready_queue_dump: str | None + graph_execution_dump: str | None + response_coordinator_dump: str | None + paused_nodes: tuple[str, ...] + + class GraphRuntimeState: """Mutable runtime state shared across graph execution components.""" @@ -293,69 +311,28 @@ class GraphRuntimeState: return json.dumps(snapshot, default=pydantic_encoder) - def loads(self, data: str | Mapping[str, Any]) -> None: + @classmethod + def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState: """Restore runtime state from a serialized snapshot.""" - payload: dict[str, Any] - if isinstance(data, str): - payload = json.loads(data) - else: - payload = dict(data) + snapshot = cls._parse_snapshot_payload(data) - version = payload.get("version") - if version != "1.0": - raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") + state = cls( + variable_pool=snapshot.variable_pool, + start_at=snapshot.start_at, + total_tokens=snapshot.total_tokens, + llm_usage=snapshot.llm_usage, + outputs=snapshot.outputs, + node_run_steps=snapshot.node_run_steps, + ) + state._apply_snapshot(snapshot) + return state - self._start_at = float(payload.get("start_at", 0.0)) - total_tokens = int(payload.get("total_tokens", 0)) - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens + def loads(self, data: str | Mapping[str, Any]) -> None: + """Restore runtime state from a serialized snapshot (legacy API).""" - node_run_steps = int(payload.get("node_run_steps", 0)) - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - llm_usage_payload = payload.get("llm_usage", {}) - self._llm_usage = LLMUsage.model_validate(llm_usage_payload) - - self._outputs = deepcopy(payload.get("outputs", {})) - - variable_pool_payload = payload.get("variable_pool") - if variable_pool_payload is not None: - self._variable_pool = VariablePool.model_validate(variable_pool_payload) - - ready_queue_payload = payload.get("ready_queue") - if ready_queue_payload is not None: - self._ready_queue = self._build_ready_queue() - self._ready_queue.loads(ready_queue_payload) - else: - self._ready_queue = None - - graph_execution_payload = payload.get("graph_execution") - self._graph_execution = None - self._pending_graph_execution_workflow_id = None - if graph_execution_payload is not None: - try: - execution_payload = json.loads(graph_execution_payload) - self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") - except (json.JSONDecodeError, TypeError, AttributeError): - self._pending_graph_execution_workflow_id = None - self.graph_execution.loads(graph_execution_payload) - - response_payload = payload.get("response_coordinator") - if response_payload is not None: - if self._graph is not None: - self.response_coordinator.loads(response_payload) - else: - self._pending_response_coordinator_dump = response_payload - else: - self._pending_response_coordinator_dump = None - self._response_coordinator = None - - paused_nodes_payload = payload.get("paused_nodes", []) - self._paused_nodes = set(map(str, paused_nodes_payload)) + snapshot = self._parse_snapshot_payload(data) + self._apply_snapshot(snapshot) def register_paused_node(self, node_id: str) -> None: """Record a node that should resume when execution is continued.""" @@ -391,3 +368,106 @@ class GraphRuntimeState: module = importlib.import_module("core.workflow.graph_engine.response_coordinator") coordinator_cls = module.ResponseStreamCoordinator return coordinator_cls(variable_pool=self.variable_pool, graph=graph) + + # ------------------------------------------------------------------ + # Snapshot helpers + # ------------------------------------------------------------------ + @classmethod + def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot: + payload: dict[str, Any] + if isinstance(data, str): + payload = json.loads(data) + else: + payload = dict(data) + + version = payload.get("version") + if version != "1.0": + raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") + + start_at = float(payload.get("start_at", 0.0)) + + total_tokens = int(payload.get("total_tokens", 0)) + if total_tokens < 0: + raise ValueError("total_tokens must be non-negative") + + node_run_steps = int(payload.get("node_run_steps", 0)) + if node_run_steps < 0: + raise ValueError("node_run_steps must be non-negative") + + llm_usage_payload = payload.get("llm_usage", {}) + llm_usage = LLMUsage.model_validate(llm_usage_payload) + + outputs_payload = deepcopy(payload.get("outputs", {})) + + variable_pool_payload = payload.get("variable_pool") + has_variable_pool = variable_pool_payload is not None + variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool() + + ready_queue_payload = payload.get("ready_queue") + graph_execution_payload = payload.get("graph_execution") + response_payload = payload.get("response_coordinator") + paused_nodes_payload = payload.get("paused_nodes", []) + + return _GraphRuntimeStateSnapshot( + start_at=start_at, + total_tokens=total_tokens, + node_run_steps=node_run_steps, + llm_usage=llm_usage, + outputs=outputs_payload, + variable_pool=variable_pool, + has_variable_pool=has_variable_pool, + ready_queue_dump=ready_queue_payload, + graph_execution_dump=graph_execution_payload, + response_coordinator_dump=response_payload, + paused_nodes=tuple(map(str, paused_nodes_payload)), + ) + + def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: + self._start_at = snapshot.start_at + self._total_tokens = snapshot.total_tokens + self._node_run_steps = snapshot.node_run_steps + self._llm_usage = snapshot.llm_usage.model_copy() + self._outputs = deepcopy(snapshot.outputs) + if snapshot.has_variable_pool or self._variable_pool is None: + self._variable_pool = snapshot.variable_pool + + self._restore_ready_queue(snapshot.ready_queue_dump) + self._restore_graph_execution(snapshot.graph_execution_dump) + self._restore_response_coordinator(snapshot.response_coordinator_dump) + self._paused_nodes = set(snapshot.paused_nodes) + + def _restore_ready_queue(self, payload: str | None) -> None: + if payload is not None: + self._ready_queue = self._build_ready_queue() + self._ready_queue.loads(payload) + else: + self._ready_queue = None + + def _restore_graph_execution(self, payload: str | None) -> None: + self._graph_execution = None + self._pending_graph_execution_workflow_id = None + + if payload is None: + return + + try: + execution_payload = json.loads(payload) + self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") + except (json.JSONDecodeError, TypeError, AttributeError): + self._pending_graph_execution_workflow_id = None + + self.graph_execution.loads(payload) + + def _restore_response_coordinator(self, payload: str | None) -> None: + if payload is None: + self._pending_response_coordinator_dump = None + self._response_coordinator = None + return + + if self._graph is not None: + self.response_coordinator.loads(payload) + self._pending_response_coordinator_dump = None + return + + self._pending_response_coordinator_dump = payload + self._response_coordinator = None diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 5ecaeb60ac..deff06fc5d 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -8,6 +8,18 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +class StubCoordinator: + def __init__(self) -> None: + self.state = "initial" + + def dumps(self) -> str: + return json.dumps({"state": self.state}) + + def loads(self, data: str) -> None: + payload = json.loads(data) + self.state = payload["state"] + + class TestGraphRuntimeState: def test_property_getters_and_setters(self): # FIXME(-LAN-): Mock VariablePool if needed @@ -191,17 +203,6 @@ class TestGraphRuntimeState: graph_execution.exceptions_count = 4 graph_execution.started = True - class StubCoordinator: - def __init__(self) -> None: - self.state = "initial" - - def dumps(self) -> str: - return json.dumps({"state": self.state}) - - def loads(self, data: str) -> None: - payload = json.loads(data) - self.state = payload["state"] - mock_graph = MagicMock() stub = StubCoordinator() with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub): @@ -211,8 +212,7 @@ class TestGraphRuntimeState: snapshot = state.dumps() - restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - restored.loads(snapshot) + restored = GraphRuntimeState.from_snapshot(snapshot) assert restored.total_tokens == 10 assert restored.node_run_steps == 3 @@ -235,3 +235,47 @@ class TestGraphRuntimeState: restored.attach_graph(mock_graph) assert new_stub.state == "configured" + + def test_loads_rehydrates_existing_instance(self): + variable_pool = VariablePool() + variable_pool.add(("node", "key"), "value") + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + state.total_tokens = 7 + state.node_run_steps = 2 + state.set_output("foo", "bar") + state.ready_queue.put("node-1") + + execution = state.graph_execution + execution.workflow_id = "wf-456" + execution.started = True + + mock_graph = MagicMock() + original_stub = StubCoordinator() + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub): + state.attach_graph(mock_graph) + + original_stub.state = "configured" + snapshot = state.dumps() + + new_stub = StubCoordinator() + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub): + restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + restored.attach_graph(mock_graph) + restored.loads(snapshot) + + assert restored.total_tokens == 7 + assert restored.node_run_steps == 2 + assert restored.get_output("foo") == "bar" + assert restored.ready_queue.qsize() == 1 + assert restored.ready_queue.get(timeout=0.01) == "node-1" + + restored_segment = restored.variable_pool.get(("node", "key")) + assert restored_segment is not None + assert restored_segment.value == "value" + + restored_execution = restored.graph_execution + assert restored_execution.workflow_id == "wf-456" + assert restored_execution.started is True + + assert new_stub.state == "configured"