mirror of https://github.com/langgenius/dify.git
feat(graph_engine): allow to dumps and loads RSC
This commit is contained in:
parent
02d15ebd5a
commit
73a7756350
|
|
@ -17,6 +17,7 @@ class GraphRuntimeState(BaseModel):
|
|||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
_ready_queue_json: str = PrivateAttr()
|
||||
_graph_execution_json: str = PrivateAttr()
|
||||
_response_coordinator_json: str = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -29,6 +30,7 @@ class GraphRuntimeState(BaseModel):
|
|||
node_run_steps: int = 0,
|
||||
ready_queue_json: str = "",
|
||||
graph_execution_json: str = "",
|
||||
response_coordinator_json: str = "",
|
||||
**kwargs: object,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
|
|
@ -57,6 +59,7 @@ class GraphRuntimeState(BaseModel):
|
|||
|
||||
self._ready_queue_json = ready_queue_json
|
||||
self._graph_execution_json = graph_execution_json
|
||||
self._response_coordinator_json = response_coordinator_json
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
|
|
@ -151,7 +154,7 @@ class GraphRuntimeState(BaseModel):
|
|||
"""Get a copy of the serialized graph execution state."""
|
||||
return self._graph_execution_json
|
||||
|
||||
@graph_execution_json.setter
|
||||
def graph_execution_json(self, value: str) -> None:
|
||||
"""Set the serialized graph execution state."""
|
||||
self._graph_execution_json = value
|
||||
@property
|
||||
def response_coordinator_json(self) -> str:
|
||||
"""Get a copy of the serialized response coordinator state."""
|
||||
return self._response_coordinator_json
|
||||
|
|
|
|||
|
|
@ -105,6 +105,8 @@ class GraphEngine:
|
|||
self._response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||
)
|
||||
if graph_runtime_state.response_coordinator_json != "":
|
||||
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
|
||||
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
|
|
|
|||
|
|
@ -9,9 +9,11 @@ import logging
|
|||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import TypeAlias, final
|
||||
from typing import Literal, TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@ -28,6 +30,43 @@ NodeID: TypeAlias = str
|
|||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
class ResponseSessionState(BaseModel):
|
||||
"""Serializable representation of a response session."""
|
||||
|
||||
node_id: str
|
||||
index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class StreamBufferState(BaseModel):
|
||||
"""Serializable representation of buffered stream chunks."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamPositionState(BaseModel):
|
||||
"""Serializable representation for stream read positions."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
position: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorState(BaseModel):
|
||||
"""Serialized snapshot of ResponseStreamCoordinator."""
|
||||
|
||||
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
|
||||
version: str = Field(default="1.0")
|
||||
response_nodes: Sequence[str] = Field(default_factory=list)
|
||||
active_session: ResponseSessionState | None = None
|
||||
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
node_execution_ids: dict[str, str] = Field(default_factory=dict)
|
||||
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
|
||||
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
|
||||
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
|
||||
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@final
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
|
|
@ -69,6 +108,8 @@ class ResponseStreamCoordinator:
|
|||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self._lock:
|
||||
if response_node_id in self._response_nodes:
|
||||
return
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
|
|
@ -558,3 +599,98 @@ class ResponseStreamCoordinator:
|
|||
"""
|
||||
key = tuple(selector)
|
||||
return key in self._closed_streams
|
||||
|
||||
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
|
||||
"""Convert an in-memory session into its serializable form."""
|
||||
|
||||
if session is None:
|
||||
return None
|
||||
return ResponseSessionState(node_id=session.node_id, index=session.index)
|
||||
|
||||
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
|
||||
"""Rebuild a response session from serialized data."""
|
||||
|
||||
node = self._graph.nodes.get(session_state.node_id)
|
||||
if node is None:
|
||||
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
|
||||
|
||||
session = ResponseSession.from_node(node)
|
||||
session.index = session_state.index
|
||||
return session
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize coordinator state to JSON."""
|
||||
|
||||
with self._lock:
|
||||
state = ResponseStreamCoordinatorState(
|
||||
response_nodes=sorted(self._response_nodes),
|
||||
active_session=self._serialize_session(self._active_session),
|
||||
waiting_sessions=[
|
||||
session_state
|
||||
for session in list(self._waiting_sessions)
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
pending_sessions=[
|
||||
session_state
|
||||
for _, session in sorted(self._response_sessions.items())
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
|
||||
paths_map={
|
||||
node_id: [path.edges.copy() for path in paths]
|
||||
for node_id, paths in sorted(self._paths_maps.items())
|
||||
},
|
||||
stream_buffers=[
|
||||
StreamBufferState(
|
||||
selector=selector,
|
||||
events=[event.model_copy(deep=True) for event in events],
|
||||
)
|
||||
for selector, events in sorted(self._stream_buffers.items())
|
||||
],
|
||||
stream_positions=[
|
||||
StreamPositionState(selector=selector, position=position)
|
||||
for selector, position in sorted(self._stream_positions.items())
|
||||
],
|
||||
closed_streams=sorted(self._closed_streams),
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore coordinator state from JSON."""
|
||||
|
||||
state = ResponseStreamCoordinatorState.model_validate_json(data)
|
||||
|
||||
if state.type != "ResponseStreamCoordinator":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
with self._lock:
|
||||
self._response_nodes = set(state.response_nodes)
|
||||
self._paths_maps = {
|
||||
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
|
||||
for node_id, paths in state.paths_map.items()
|
||||
}
|
||||
self._node_execution_ids = dict(state.node_execution_ids)
|
||||
|
||||
self._stream_buffers = {
|
||||
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
|
||||
for buffer in state.stream_buffers
|
||||
}
|
||||
self._stream_positions = {
|
||||
tuple(position.selector): position.position for position in state.stream_positions
|
||||
}
|
||||
for selector in self._stream_buffers:
|
||||
self._stream_positions.setdefault(selector, 0)
|
||||
|
||||
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
|
||||
|
||||
self._waiting_sessions = deque(
|
||||
self._session_from_state(session_state) for session_state in state.waiting_sessions
|
||||
)
|
||||
self._response_sessions = {
|
||||
session_state.node_id: self._session_from_state(session_state)
|
||||
for session_state in state.pending_sessions
|
||||
}
|
||||
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
|
||||
|
|
|
|||
|
|
@ -3,12 +3,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph_engine.domain import GraphExecution
|
||||
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_engine.response_coordinator.path import Path
|
||||
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
|
||||
|
||||
|
||||
class CustomGraphExecutionError(Exception):
|
||||
|
|
@ -88,78 +92,103 @@ def test_graph_execution_loads_replaces_existing_state() -> None:
|
|||
assert restored_node.error is None
|
||||
|
||||
|
||||
def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None:
|
||||
"""GraphEngine restores GraphExecution state from runtime snapshot on init."""
|
||||
def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None:
|
||||
"""ResponseStreamCoordinator serialization restores coordinator internals."""
|
||||
|
||||
# Arrange serialized execution state
|
||||
execution = GraphExecution(workflow_id="wf-init")
|
||||
execution.start()
|
||||
node_state = execution.get_or_create_node_execution("serialized-node")
|
||||
node_state.mark_taken()
|
||||
execution.complete()
|
||||
serialized = execution.dumps()
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=MagicMock(),
|
||||
start_at=0.0,
|
||||
graph_execution_json=serialized,
|
||||
)
|
||||
template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])])
|
||||
template_secondary = Template(segments=[TextSegment(text="secondary")])
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.execution_type = NodeExecutionType.EXECUTABLE
|
||||
self.id = "dummy-node"
|
||||
def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None:
|
||||
self.id = node_id
|
||||
self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM
|
||||
self.execution_type = execution_type
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = "dummy"
|
||||
self.title = node_id
|
||||
self.template = template
|
||||
|
||||
def blocks_variable_output(self, *_args) -> bool:
|
||||
return False
|
||||
|
||||
response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE)
|
||||
response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE)
|
||||
response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE)
|
||||
source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE)
|
||||
|
||||
class DummyGraph:
|
||||
def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
|
||||
self.nodes = {"dummy-node": DummyNode(graph_runtime_state)}
|
||||
def __init__(self) -> None:
|
||||
self.nodes = {
|
||||
response_node1.id: response_node1,
|
||||
response_node2.id: response_node2,
|
||||
response_node3.id: response_node3,
|
||||
source_node.id: source_node,
|
||||
}
|
||||
self.edges: dict[str, object] = {}
|
||||
self.root_node = self.nodes["dummy-node"]
|
||||
self.root_node = response_node1
|
||||
|
||||
def get_incoming_edges(self, node_id: str): # pragma: no cover - not exercised
|
||||
def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||
return []
|
||||
|
||||
def get_outgoing_edges(self, node_id: str): # pragma: no cover - not exercised
|
||||
def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||
return []
|
||||
|
||||
dummy_graph = DummyGraph(runtime_state)
|
||||
graph = DummyGraph()
|
||||
|
||||
def _stub(*_args, **_kwargs):
|
||||
return MagicMock()
|
||||
def fake_from_node(cls, node: DummyNode) -> ResponseSession:
|
||||
return ResponseSession(node_id=node.id, template=node.template)
|
||||
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub)
|
||||
command_processor = MagicMock()
|
||||
command_processor.register_handler = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.graph_engine.graph_engine.CommandProcessor",
|
||||
lambda *_args, **_kwargs: command_processor,
|
||||
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||
|
||||
coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||
coordinator._response_nodes = {"response-1", "response-2", "response-3"}
|
||||
coordinator._paths_maps = {
|
||||
"response-1": [Path(edges=["edge-1"])],
|
||||
"response-2": [Path(edges=[])],
|
||||
"response-3": [Path(edges=["edge-2", "edge-3"])],
|
||||
}
|
||||
|
||||
active_session = ResponseSession(node_id="response-1", template=response_node1.template)
|
||||
active_session.index = 1
|
||||
coordinator._active_session = active_session
|
||||
waiting_session = ResponseSession(node_id="response-2", template=response_node2.template)
|
||||
coordinator._waiting_sessions = deque([waiting_session])
|
||||
pending_session = ResponseSession(node_id="response-3", template=response_node3.template)
|
||||
pending_session.index = 2
|
||||
coordinator._response_sessions = {"response-3": pending_session}
|
||||
|
||||
coordinator._node_execution_ids = {"response-1": "exec-1"}
|
||||
event = NodeRunStreamChunkEvent(
|
||||
id="exec-1",
|
||||
node_id="response-1",
|
||||
node_type=NodeType.ANSWER,
|
||||
selector=["node-source", "text"],
|
||||
chunk="chunk-1",
|
||||
is_final=False,
|
||||
)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub)
|
||||
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub)
|
||||
coordinator._stream_buffers = {("node-source", "text"): [event]}
|
||||
coordinator._stream_positions = {("node-source", "text"): 1}
|
||||
coordinator._closed_streams = {("node-source", "text")}
|
||||
|
||||
# Act
|
||||
engine = GraphEngine(
|
||||
workflow_id="wf-init",
|
||||
graph=dummy_graph, # type: ignore[arg-type]
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=MagicMock(),
|
||||
)
|
||||
serialized = coordinator.dumps()
|
||||
|
||||
# Assert
|
||||
assert engine._graph_execution.started is True
|
||||
assert engine._graph_execution.completed is True
|
||||
assert set(engine._graph_execution.node_executions) == {"serialized-node"}
|
||||
restored_node = engine._graph_execution.node_executions["serialized-node"]
|
||||
assert restored_node.state is NodeState.TAKEN
|
||||
assert restored_node.retry_count == 0
|
||||
restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||
restored.loads(serialized)
|
||||
|
||||
assert restored._response_nodes == {"response-1", "response-2", "response-3"}
|
||||
assert restored._paths_maps["response-1"][0].edges == ["edge-1"]
|
||||
assert restored._active_session is not None
|
||||
assert restored._active_session.node_id == "response-1"
|
||||
assert restored._active_session.index == 1
|
||||
waiting_restored = list(restored._waiting_sessions)
|
||||
assert len(waiting_restored) == 1
|
||||
assert waiting_restored[0].node_id == "response-2"
|
||||
assert waiting_restored[0].index == 0
|
||||
assert set(restored._response_sessions) == {"response-3"}
|
||||
assert restored._response_sessions["response-3"].index == 2
|
||||
assert restored._node_execution_ids == {"response-1": "exec-1"}
|
||||
assert ("node-source", "text") in restored._stream_buffers
|
||||
restored_event = restored._stream_buffers[("node-source", "text")][0]
|
||||
assert restored_event.chunk == "chunk-1"
|
||||
assert restored._stream_positions[("node-source", "text")] == 1
|
||||
assert ("node-source", "text") in restored._closed_streams
|
||||
|
|
|
|||
Loading…
Reference in New Issue