mirror of https://github.com/langgenius/dify.git
feat(graph_engine): support dumps and loads in GraphExecution
parent 976b3b5e83
commit 02d15ebd5a
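The commit below adds JSON snapshot support to the GraphExecution aggregate and threads it through GraphRuntimeState and GraphEngine. As a quick orientation, here is a minimal usage sketch (not part of the commit; it only uses names that appear in the diff and tests below, and assumes GraphExecution is importable from core.workflow.graph_engine.domain as the tests do):

# Hedged sketch: suspend and resume a GraphExecution via the new dumps()/loads().
from core.workflow.graph_engine.domain import GraphExecution

execution = GraphExecution(workflow_id="wf-demo")
execution.start()
execution.get_or_create_node_execution("node-1").mark_taken()

snapshot = execution.dumps()  # JSON string describing the aggregate state

resumed = GraphExecution(workflow_id="wf-demo")  # workflow_id must match the snapshot
resumed.loads(snapshot)  # restores started/completed/aborted, error, and node_executions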
@@ -16,6 +16,7 @@ class GraphRuntimeState(BaseModel):
    _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
    _node_run_steps: int = PrivateAttr(default=0)
    _ready_queue_json: str = PrivateAttr()
    _graph_execution_json: str = PrivateAttr()

    def __init__(
        self,
@@ -27,6 +28,7 @@ class GraphRuntimeState(BaseModel):
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue_json: str = "",
        graph_execution_json: str = "",
        **kwargs: object,
    ):
        """Initialize the GraphRuntimeState with validation."""
@@ -54,6 +56,7 @@ class GraphRuntimeState(BaseModel):
        self._node_run_steps = node_run_steps

        self._ready_queue_json = ready_queue_json
        self._graph_execution_json = graph_execution_json

    @property
    def variable_pool(self) -> VariablePool:
@@ -142,3 +145,13 @@ class GraphRuntimeState(BaseModel):
    def ready_queue_json(self) -> str:
        """Get a copy of the ready queue state."""
        return self._ready_queue_json

    @property
    def graph_execution_json(self) -> str:
        """Get a copy of the serialized graph execution state."""
        return self._graph_execution_json

    @graph_execution_json.setter
    def graph_execution_json(self, value: str) -> None:
        """Set the serialized graph execution state."""
        self._graph_execution_json = value
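The new graph_execution_json field mirrors the existing ready_queue_json plumbing: callers pass a serialized snapshot into the constructor when resuming, and can write one back through the setter before persisting the runtime state. A hedged sketch follows; the surrounding `execution`, `runtime_state`, and `variable_pool` objects are assumed to already exist in the caller's scope:

# Persist the execution snapshot into the runtime state before suspending.
runtime_state.graph_execution_json = execution.dumps()

# Later, rebuild a runtime state that carries the snapshot for GraphEngine to restore.
resumed_state = GraphRuntimeState(
    variable_pool=variable_pool,
    start_at=0.0,
    graph_execution_json=runtime_state.graph_execution_json,
)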
@@ -1,12 +1,94 @@
"""GraphExecution aggregate root managing the overall graph execution state."""

from __future__ import annotations

from dataclasses import dataclass, field
from importlib import import_module
from typing import Literal

from pydantic import BaseModel, Field

from core.workflow.enums import NodeState

from .node_execution import NodeExecution


class GraphExecutionErrorState(BaseModel):
    """Serializable representation of an execution error."""

    module: str = Field(description="Module containing the exception class")
    qualname: str = Field(description="Qualified name of the exception class")
    message: str | None = Field(default=None, description="Exception message string")


class NodeExecutionState(BaseModel):
    """Serializable representation of a node execution entity."""

    node_id: str
    state: NodeState = Field(default=NodeState.UNKNOWN)
    retry_count: int = Field(default=0)
    execution_id: str | None = Field(default=None)
    error: str | None = Field(default=None)


class GraphExecutionState(BaseModel):
    """Pydantic model describing serialized GraphExecution state."""

    type: Literal["GraphExecution"] = Field(default="GraphExecution")
    version: str = Field(default="1.0")
    workflow_id: str
    started: bool = Field(default=False)
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    error: GraphExecutionErrorState | None = Field(default=None)
    node_executions: list[NodeExecutionState] = Field(default_factory=list)


def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
    """Convert an exception into its serializable representation."""

    if error is None:
        return None

    return GraphExecutionErrorState(
        module=error.__class__.__module__,
        qualname=error.__class__.__qualname__,
        message=str(error),
    )


def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
    """Locate an exception class from its module and qualified name."""

    module = import_module(module_name)
    attr: object = module
    for part in qualname.split("."):
        attr = getattr(attr, part)

    if isinstance(attr, type) and issubclass(attr, Exception):
        return attr

    raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")


def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
    """Reconstruct an exception instance from serialized data."""

    if state is None:
        return None

    try:
        exception_class = _resolve_exception_class(state.module, state.qualname)
        if state.message is None:
            return exception_class()
        return exception_class(state.message)
    except Exception:
        # Fallback to RuntimeError when reconstruction fails
        if state.message is None:
            return RuntimeError(state.qualname)
        return RuntimeError(state.message)


@dataclass
class GraphExecution:
    """
@@ -69,3 +151,57 @@ class GraphExecution:
        if not self.error:
            return None
        return str(self.error)

    def dumps(self) -> str:
        """Serialize the aggregate state into a JSON string."""

        node_states = [
            NodeExecutionState(
                node_id=node_id,
                state=node_execution.state,
                retry_count=node_execution.retry_count,
                execution_id=node_execution.execution_id,
                error=node_execution.error,
            )
            for node_id, node_execution in sorted(self.node_executions.items())
        ]

        state = GraphExecutionState(
            workflow_id=self.workflow_id,
            started=self.started,
            completed=self.completed,
            aborted=self.aborted,
            error=_serialize_error(self.error),
            node_executions=node_states,
        )

        return state.model_dump_json()

    def loads(self, data: str) -> None:
        """Restore aggregate state from a serialized JSON string."""

        state = GraphExecutionState.model_validate_json(data)

        if state.type != "GraphExecution":
            raise ValueError(f"Invalid serialized data type: {state.type}")

        if state.version != "1.0":
            raise ValueError(f"Unsupported serialized version: {state.version}")

        if self.workflow_id != state.workflow_id:
            raise ValueError("Serialized workflow_id does not match aggregate identity")

        self.started = state.started
        self.completed = state.completed
        self.aborted = state.aborted
        self.error = _deserialize_error(state.error)
        self.node_executions = {
            item.node_id: NodeExecution(
                node_id=item.node_id,
                state=item.state,
                retry_count=item.retry_count,
                execution_id=item.execution_id,
                error=item.error,
            )
            for item in state.node_executions
        }
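Taken together with the GraphExecutionState model above, dumps() emits a versioned JSON envelope and loads() validates the type, version, and workflow_id before overwriting the aggregate. A hedged illustration of inspecting that envelope (reusing `execution` from the sketch near the top of this page; the exact encoding of each node's `state` follows NodeState's own serialization and is not spelled out here):

import json

payload = json.loads(execution.dumps())
assert payload["type"] == "GraphExecution"
assert payload["version"] == "1.0"
assert payload["workflow_id"] == execution.workflow_id
# payload["error"] is either None or a dict with module, qualname, and message keys.
# payload["node_executions"] is a list of dicts with node_id, state, retry_count,
# execution_id, and error keys, sorted by node_id.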
@@ -68,6 +68,8 @@ class GraphEngine:

        # Graph execution tracks the overall execution state
        self._graph_execution = GraphExecution(workflow_id=workflow_id)
        if graph_runtime_state.graph_execution_json != "":
            self._graph_execution.loads(graph_runtime_state.graph_execution_json)

        # === Core Dependencies ===
        # Graph structure and configuration
@@ -0,0 +1,165 @@
"""Unit tests for GraphExecution serialization helpers."""

from __future__ import annotations

import json
from unittest.mock import MagicMock

from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.domain import GraphExecution


class CustomGraphExecutionError(Exception):
    """Custom exception used to verify error serialization."""


def test_graph_execution_serialization_round_trip() -> None:
    """GraphExecution serialization restores full aggregate state."""
    # Arrange
    execution = GraphExecution(workflow_id="wf-1")
    execution.start()
    node_a = execution.get_or_create_node_execution("node-a")
    node_a.mark_started(execution_id="exec-1")
    node_a.increment_retry()
    node_a.mark_failed("boom")
    node_b = execution.get_or_create_node_execution("node-b")
    node_b.mark_skipped()
    execution.fail(CustomGraphExecutionError("serialization failure"))

    # Act
    serialized = execution.dumps()
    payload = json.loads(serialized)
    restored = GraphExecution(workflow_id="wf-1")
    restored.loads(serialized)

    # Assert
    assert payload["type"] == "GraphExecution"
    assert payload["version"] == "1.0"
    assert restored.workflow_id == "wf-1"
    assert restored.started is True
    assert restored.completed is True
    assert restored.aborted is False
    assert isinstance(restored.error, CustomGraphExecutionError)
    assert str(restored.error) == "serialization failure"
    assert set(restored.node_executions) == {"node-a", "node-b"}
    restored_node_a = restored.node_executions["node-a"]
    assert restored_node_a.state is NodeState.TAKEN
    assert restored_node_a.retry_count == 1
    assert restored_node_a.execution_id == "exec-1"
    assert restored_node_a.error == "boom"
    restored_node_b = restored.node_executions["node-b"]
    assert restored_node_b.state is NodeState.SKIPPED
    assert restored_node_b.retry_count == 0
    assert restored_node_b.execution_id is None
    assert restored_node_b.error is None


def test_graph_execution_loads_replaces_existing_state() -> None:
    """loads replaces existing runtime data with serialized snapshot."""
    # Arrange
    source = GraphExecution(workflow_id="wf-2")
    source.start()
    source_node = source.get_or_create_node_execution("node-source")
    source_node.mark_taken()
    serialized = source.dumps()

    target = GraphExecution(workflow_id="wf-2")
    target.start()
    target.abort("pre-existing abort")
    temp_node = target.get_or_create_node_execution("node-temp")
    temp_node.increment_retry()
    temp_node.mark_failed("temp error")

    # Act
    target.loads(serialized)

    # Assert
    assert target.aborted is False
    assert target.error is None
    assert target.started is True
    assert target.completed is False
    assert set(target.node_executions) == {"node-source"}
    restored_node = target.node_executions["node-source"]
    assert restored_node.state is NodeState.TAKEN
    assert restored_node.retry_count == 0
    assert restored_node.execution_id is None
    assert restored_node.error is None


def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None:
    """GraphEngine restores GraphExecution state from runtime snapshot on init."""

    # Arrange serialized execution state
    execution = GraphExecution(workflow_id="wf-init")
    execution.start()
    node_state = execution.get_or_create_node_execution("serialized-node")
    node_state.mark_taken()
    execution.complete()
    serialized = execution.dumps()

    runtime_state = GraphRuntimeState(
        variable_pool=MagicMock(),
        start_at=0.0,
        graph_execution_json=serialized,
    )

    class DummyNode:
        def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
            self.graph_runtime_state = graph_runtime_state
            self.execution_type = NodeExecutionType.EXECUTABLE
            self.id = "dummy-node"
            self.state = NodeState.UNKNOWN
            self.title = "dummy"

    class DummyGraph:
        def __init__(self, graph_runtime_state: GraphRuntimeState) -> None:
            self.nodes = {"dummy-node": DummyNode(graph_runtime_state)}
            self.edges: dict[str, object] = {}
            self.root_node = self.nodes["dummy-node"]

        def get_incoming_edges(self, node_id: str):  # pragma: no cover - not exercised
            return []

        def get_outgoing_edges(self, node_id: str):  # pragma: no cover - not exercised
            return []

    dummy_graph = DummyGraph(runtime_state)

    def _stub(*_args, **_kwargs):
        return MagicMock()

    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub)
    command_processor = MagicMock()
    command_processor.register_handler = MagicMock()
    monkeypatch.setattr(
        "core.workflow.graph_engine.graph_engine.CommandProcessor",
        lambda *_args, **_kwargs: command_processor,
    )
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub)
    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub)

    # Act
    engine = GraphEngine(
        workflow_id="wf-init",
        graph=dummy_graph,  # type: ignore[arg-type]
        graph_runtime_state=runtime_state,
        command_channel=MagicMock(),
    )

    # Assert
    assert engine._graph_execution.started is True
    assert engine._graph_execution.completed is True
    assert set(engine._graph_execution.node_executions) == {"serialized-node"}
    restored_node = engine._graph_execution.node_executions["serialized-node"]
    assert restored_node.state is NodeState.TAKEN
    assert restored_node.retry_count == 0