feat(graph_engine): dump and load ready queue

This commit is contained in:
-LAN- 2025-09-16 03:00:15 +08:00
parent d5342927d0
commit da87fce751
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
10 changed files with 89 additions and 182 deletions

View File

@ -31,7 +31,6 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
core.workflow.entities.graph_runtime_state -> core.workflow.graph_engine.ready_queue
[importlinter:contract:rsc]
name = RSC

View File

@ -1,5 +1,4 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, PrivateAttr
@ -7,9 +6,6 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.graph_engine.ready_queue import ReadyQueueState
class GraphRuntimeState(BaseModel):
# Private attributes to prevent direct modification
@ -19,17 +15,18 @@ class GraphRuntimeState(BaseModel):
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict)
_ready_queue_json: str = PrivateAttr()
def __init__(
self,
*,
variable_pool: VariablePool,
start_at: float,
total_tokens: int = 0,
llm_usage: LLMUsage | None = None,
outputs: dict[str, Any] | None = None,
outputs: dict[str, object] | None = None,
node_run_steps: int = 0,
ready_queue: "ReadyQueueState | dict[str, object] | None" = None,
ready_queue_json: str = "",
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
@ -56,9 +53,7 @@ class GraphRuntimeState(BaseModel):
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
if ready_queue is None:
ready_queue = {}
self._ready_queue = deepcopy(ready_queue)
self._ready_queue_json = ready_queue_json
@property
def variable_pool(self) -> VariablePool:
@ -99,24 +94,24 @@ class GraphRuntimeState(BaseModel):
self._llm_usage = value.model_copy()
@property
def outputs(self) -> dict[str, Any]:
def outputs(self) -> dict[str, object]:
"""Get a copy of the outputs dictionary."""
return deepcopy(self._outputs)
@outputs.setter
def outputs(self, value: dict[str, Any]) -> None:
def outputs(self, value: dict[str, object]) -> None:
"""Set the outputs dictionary."""
self._outputs = deepcopy(value)
def set_output(self, key: str, value: Any) -> None:
def set_output(self, key: str, value: object) -> None:
"""Set a single output value."""
self._outputs[key] = deepcopy(value)
def get_output(self, key: str, default: Any = None) -> Any:
def get_output(self, key: str, default: object = None) -> object:
"""Get a single output value."""
return deepcopy(self._outputs.get(key, default))
def update_outputs(self, updates: dict[str, Any]) -> None:
def update_outputs(self, updates: dict[str, object]) -> None:
"""Update multiple output values."""
for key, value in updates.items():
self._outputs[key] = deepcopy(value)
@ -144,6 +139,6 @@ class GraphRuntimeState(BaseModel):
self._total_tokens += tokens
@property
def ready_queue(self) -> "ReadyQueueState | dict[str, object]":
def ready_queue_json(self) -> str:
"""Get a copy of the ready queue state."""
return deepcopy(self._ready_queue)
return self._ready_queue_json

View File

@ -18,6 +18,7 @@ from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
@ -38,7 +39,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import InMemoryReadyQueue
from .ready_queue import ReadyQueueState, create_ready_queue_from_state
from .response_coordinator import ResponseStreamCoordinator
from .worker_management import WorkerPool
@ -104,18 +105,13 @@ class GraphEngine:
self._scale_down_idle_time = scale_down_idle_time
# === Execution Queues ===
# Queue for nodes ready to execute
self._ready_queue = InMemoryReadyQueue()
# Load ready queue state from GraphRuntimeState if not empty
ready_queue_state = self._graph_runtime_state.ready_queue
if ready_queue_state:
# Import ReadyQueueState here to avoid circular imports
from .ready_queue import ReadyQueueState
# Create ready queue from saved state or initialize new one
if self._graph_runtime_state.ready_queue_json == "":
self._ready_queue = InMemoryReadyQueue()
else:
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
# Ensure we have a ReadyQueueState object
if isinstance(ready_queue_state, dict):
ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore
self._ready_queue.loads(ready_queue_state)
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

View File

@ -5,7 +5,8 @@ This package contains the protocol and implementations for managing
the queue of nodes ready for execution.
"""
from .factory import create_ready_queue_from_state
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueue, ReadyQueueState
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"]
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]

View File

@ -0,0 +1,35 @@
"""
Factory for creating ReadyQueue instances from serialized state.
"""
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueueState
if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
"""
Create a ReadyQueue instance from a serialized state.
Args:
state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue
Returns:
A ReadyQueue instance initialized with the given state
Raises:
ValueError: If the queue type is unknown or version is unsupported
"""
if state.type == "InMemoryReadyQueue":
if state.version != "1.0":
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
queue = InMemoryReadyQueue()
# Always pass as JSON string to loads()
queue.loads(state.model_dump_json())
return queue
else:
raise ValueError(f"Unknown ready queue type: {state.type}")

View File

@ -82,12 +82,12 @@ class InMemoryReadyQueue:
"""
return self._queue.qsize()
def dumps(self) -> ReadyQueueState:
def dumps(self) -> str:
"""
Serialize the queue state for storage.
Serialize the queue state to a JSON string for storage.
Returns:
A ReadyQueueState dictionary containing the serialized queue state
A JSON string containing the serialized queue state
"""
# Extract all items from the queue without removing them
items: list[str] = []
@ -106,25 +106,27 @@ class InMemoryReadyQueue:
for item in temp_items:
self._queue.put(item)
return ReadyQueueState(
state = ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=items,
maxsize=self._queue.maxsize,
)
return state.model_dump_json()
def loads(self, data: ReadyQueueState) -> None:
def loads(self, data: str) -> None:
"""
Restore the queue state from serialized data.
Restore the queue state from a JSON string.
Args:
data: The serialized queue state to restore
data: The JSON string containing the serialized queue state to restore
"""
if data.get("type") != "InMemoryReadyQueue":
raise ValueError(f"Invalid serialized data type: {data.get('type')}")
state = ReadyQueueState.model_validate_json(data)
if data.get("version") != "1.0":
raise ValueError(f"Unsupported version: {data.get('version')}")
if state.type != "InMemoryReadyQueue":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported version: {state.version}")
# Clear the current queue
while not self._queue.empty():
@ -134,11 +136,5 @@ class InMemoryReadyQueue:
break
# Restore items
items = data.get("items", [])
if not isinstance(items, list):
raise ValueError("Invalid items data: expected list")
for item in items:
if not isinstance(item, str):
raise ValueError(f"Invalid item type: expected str, got {type(item).__name__}")
for item in state.items:
self._queue.put(item)

View File

@ -5,21 +5,23 @@ This protocol defines the interface for managing the queue of nodes ready
for execution, supporting both in-memory and persistent storage scenarios.
"""
from typing import Protocol, TypedDict
from collections.abc import Sequence
from typing import Protocol
from pydantic import BaseModel, Field
class ReadyQueueState(TypedDict):
class ReadyQueueState(BaseModel):
"""
TypedDict for serialized ready queue state.
Pydantic model for serialized ready queue state.
This defines the structure of the dictionary returned by dumps()
This defines the structure of the data returned by dumps()
and expected by loads() for ready queue serialization.
"""
type: str # Queue implementation type (e.g., "InMemoryReadyQueue")
version: str # Serialization format version
items: list[str] # List of node IDs in the queue
maxsize: int # Maximum queue size (0 for unlimited)
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
version: str = Field(description="Serialization format version")
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
class ReadyQueue(Protocol):
@ -82,21 +84,21 @@ class ReadyQueue(Protocol):
"""
...
def dumps(self) -> ReadyQueueState:
def dumps(self) -> str:
"""
Serialize the queue state for storage.
Serialize the queue state to a JSON string for storage.
Returns:
A ReadyQueueState dictionary containing the serialized queue state
A JSON string containing the serialized queue state
that can be persisted and later restored
"""
...
def loads(self, data: ReadyQueueState) -> None:
def loads(self, data: str) -> None:
"""
Restore the queue state from serialized data.
Restore the queue state from a JSON string.
Args:
data: The serialized queue state to restore
data: The JSON string containing the serialized queue state to restore
"""
...

View File

@ -19,7 +19,7 @@ class Path:
Note: This is an internal class not exposed in the public API.
"""
edges: list[EdgeID] = field(default_factory=list)
edges: list[EdgeID] = field(default_factory=list[EdgeID])
def contains_edge(self, edge_id: EdgeID) -> bool:
"""Check if this path contains the given edge."""

View File

@ -4,7 +4,6 @@ import pytest
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.ready_queue import ReadyQueueState
class TestGraphRuntimeState:
@ -96,44 +95,3 @@ class TestGraphRuntimeState:
# Test add_tokens validation
with pytest.raises(ValueError):
state.add_tokens(-1)
def test_deep_copy_for_nested_objects(self):
variable_pool = VariablePool()
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
# Test deep copy for nested dict
nested_data = {"level1": {"level2": {"value": "test"}}}
state.set_output("nested", nested_data)
retrieved = state.get_output("nested")
retrieved["level1"]["level2"]["value"] = "modified"
# Original should remain unchanged
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
def test_ready_queue_property(self):
variable_pool = VariablePool()
# Test default empty ready_queue
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
assert state.ready_queue == {}
# Test initialization with ready_queue data as ReadyQueueState
queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0)
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data)
assert state.ready_queue == queue_data
# Test with different ready_queue data at initialization
another_queue_data = ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=["node3", "node4", "node5"],
maxsize=0,
)
another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data)
assert another_state.ready_queue == another_queue_data
# Test immutability - modifying retrieved queue doesn't affect internal state
retrieved_queue = state.ready_queue
retrieved_queue["items"].append("node6")
assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3

View File

@ -744,78 +744,3 @@ def test_event_sequence_validation_with_table_tests():
else:
assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
def test_ready_queue_state_loading():
"""
Test that the ready_queue state is properly loaded from GraphRuntimeState
during GraphEngine initialization.
"""
# Use the TableTestRunner to create a proper workflow instance
runner = TableTestRunner()
# Create a simple workflow
test_case = WorkflowTestCase(
fixture_path="simple_passthrough_workflow",
inputs={"query": "test"},
expected_outputs={"query": "test"},
description="Test ready_queue loading",
)
# Load the workflow fixture
workflow_runner = runner.workflow_runner
fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow")
# Create graph and runtime state with pre-populated ready_queue
ready_queue_data = {
"type": "InMemoryReadyQueue",
"version": "1.0",
"items": ["node1", "node2", "node3"],
"maxsize": 0,
}
# We need to create the graph first, then create a new GraphRuntimeState with ready_queue
graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test")
# Create a new GraphRuntimeState with the ready_queue data
from core.workflow.entities import GraphRuntimeState
from core.workflow.graph_engine.ready_queue import ReadyQueueState
# Convert ready_queue_data to ReadyQueueState
ready_queue_state = ReadyQueueState(**ready_queue_data)
graph_runtime_state = GraphRuntimeState(
variable_pool=original_runtime_state.variable_pool,
start_at=original_runtime_state.start_at,
ready_queue=ready_queue_state,
)
# Update all nodes to use the new GraphRuntimeState
for node in graph.nodes.values():
node.graph_runtime_state = graph_runtime_state
# Create GraphEngine
command_channel = InMemoryChannel()
engine = GraphEngine(
tenant_id="test-tenant",
app_id="test-app",
workflow_id="test-workflow",
user_id="test-user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
graph=graph,
graph_config={},
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
# Verify that the ready_queue was loaded from GraphRuntimeState
assert engine._ready_queue.qsize() == 3
# Verify the initial state matches what was provided
initial_queue_state = engine.graph_runtime_state.ready_queue
assert initial_queue_state["type"] == "InMemoryReadyQueue"
assert initial_queue_state["version"] == "1.0"
assert len(initial_queue_state["items"]) == 3
assert initial_queue_state["items"] == ["node1", "node2", "node3"]