feat(graph_engine): add ready_queue state persistence to GraphRuntimeState

- Add ReadyQueueState TypedDict for type-safe queue serialization
- Add ready_queue attribute to GraphRuntimeState for initializing with pre-existing queue state
- Update GraphEngine to load ready_queue from GraphRuntimeState on initialization
- Implement proper type hints using ReadyQueueState for better type safety
- Add comprehensive tests for ready_queue loading functionality

The ready_queue is read-only after initialization and allows resuming workflow
execution with a pre-populated queue of nodes ready to execute.
This commit is contained in:
-LAN- 2025-09-15 03:05:10 +08:00
parent 0f15a2baca
commit b4ef1de30f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
7 changed files with 159 additions and 16 deletions

View File

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, PrivateAttr
@ -7,6 +7,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.graph_engine.ready_queue import ReadyQueueState
class GraphRuntimeState(BaseModel):
# Private attributes to prevent direct modification
@ -16,6 +19,7 @@ class GraphRuntimeState(BaseModel):
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict)
def __init__(
self,
@ -25,6 +29,7 @@ class GraphRuntimeState(BaseModel):
llm_usage: LLMUsage | None = None,
outputs: dict[str, Any] | None = None,
node_run_steps: int = 0,
ready_queue: "ReadyQueueState | dict[str, object] | None" = None,
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
@ -51,6 +56,10 @@ class GraphRuntimeState(BaseModel):
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
if ready_queue is None:
ready_queue = {}
self._ready_queue = deepcopy(ready_queue)
@property
def variable_pool(self) -> VariablePool:
"""Get the variable pool."""
@ -133,3 +142,8 @@ class GraphRuntimeState(BaseModel):
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
@property
def ready_queue(self) -> "ReadyQueueState | dict[str, object]":
"""Get a copy of the ready queue state."""
return deepcopy(self._ready_queue)

View File

@ -106,6 +106,16 @@ class GraphEngine:
# === Execution Queues ===
# Queue for nodes ready to execute
self._ready_queue = InMemoryReadyQueue()
# Load ready queue state from GraphRuntimeState if not empty
ready_queue_state = self._graph_runtime_state.ready_queue
if ready_queue_state:
# Import ReadyQueueState here to avoid circular imports
from .ready_queue import ReadyQueueState
# Ensure we have a ReadyQueueState object
if isinstance(ready_queue_state, dict):
ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore
self._ready_queue.loads(ready_queue_state)
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

View File

@ -6,6 +6,6 @@ the queue of nodes ready for execution.
"""
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueue
from .protocol import ReadyQueue, ReadyQueueState
__all__ = ["InMemoryReadyQueue", "ReadyQueue"]
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"]

View File

@ -8,6 +8,8 @@ serialization capabilities for state storage.
import queue
from typing import final
from .protocol import ReadyQueueState
@final
class InMemoryReadyQueue:
@ -80,12 +82,12 @@ class InMemoryReadyQueue:
"""
return self._queue.qsize()
def dumps(self) -> dict[str, object]:
def dumps(self) -> ReadyQueueState:
"""
Serialize the queue state for storage.
Returns:
A dictionary containing the serialized queue state
A ReadyQueueState dictionary containing the serialized queue state
"""
# Extract all items from the queue without removing them
items: list[str] = []
@ -104,14 +106,14 @@ class InMemoryReadyQueue:
for item in temp_items:
self._queue.put(item)
return {
"type": "InMemoryReadyQueue",
"version": "1.0",
"items": items,
"maxsize": self._queue.maxsize,
}
return ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=items,
maxsize=self._queue.maxsize,
)
def loads(self, data: dict[str, object]) -> None:
def loads(self, data: ReadyQueueState) -> None:
"""
Restore the queue state from serialized data.

View File

@ -5,7 +5,21 @@ This protocol defines the interface for managing the queue of nodes ready
for execution, supporting both in-memory and persistent storage scenarios.
"""
from typing import Protocol
from typing import Protocol, TypedDict
class ReadyQueueState(TypedDict):
"""
TypedDict for serialized ready queue state.
This defines the structure of the dictionary returned by dumps()
and expected by loads() for ready queue serialization.
"""
type: str # Queue implementation type (e.g., "InMemoryReadyQueue")
version: str # Serialization format version
items: list[str] # List of node IDs in the queue
maxsize: int # Maximum queue size (0 for unlimited)
class ReadyQueue(Protocol):
@ -68,17 +82,17 @@ class ReadyQueue(Protocol):
"""
...
def dumps(self) -> dict[str, object]:
def dumps(self) -> ReadyQueueState:
"""
Serialize the queue state for storage.
Returns:
A dictionary containing the serialized queue state
A ReadyQueueState dictionary containing the serialized queue state
that can be persisted and later restored
"""
...
def loads(self, data: dict[str, object]) -> None:
def loads(self, data: ReadyQueueState) -> None:
"""
Restore the queue state from serialized data.

View File

@ -4,6 +4,7 @@ import pytest
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.ready_queue import ReadyQueueState
class TestGraphRuntimeState:
@ -109,3 +110,30 @@ class TestGraphRuntimeState:
# Original should remain unchanged
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
def test_ready_queue_property(self):
variable_pool = VariablePool()
# Test default empty ready_queue
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
assert state.ready_queue == {}
# Test initialization with ready_queue data as ReadyQueueState
queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0)
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data)
assert state.ready_queue == queue_data
# Test with different ready_queue data at initialization
another_queue_data = ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=["node3", "node4", "node5"],
maxsize=0,
)
another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data)
assert another_state.ready_queue == another_queue_data
# Test immutability - modifying retrieved queue doesn't affect internal state
retrieved_queue = state.ready_queue
retrieved_queue["items"].append("node6")
assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3

View File

@ -744,3 +744,78 @@ def test_event_sequence_validation_with_table_tests():
else:
assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
def test_ready_queue_state_loading():
"""
Test that the ready_queue state is properly loaded from GraphRuntimeState
during GraphEngine initialization.
"""
# Use the TableTestRunner to create a proper workflow instance
runner = TableTestRunner()
# Create a simple workflow
test_case = WorkflowTestCase(
fixture_path="simple_passthrough_workflow",
inputs={"query": "test"},
expected_outputs={"query": "test"},
description="Test ready_queue loading",
)
# Load the workflow fixture
workflow_runner = runner.workflow_runner
fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow")
# Create graph and runtime state with pre-populated ready_queue
ready_queue_data = {
"type": "InMemoryReadyQueue",
"version": "1.0",
"items": ["node1", "node2", "node3"],
"maxsize": 0,
}
# We need to create the graph first, then create a new GraphRuntimeState with ready_queue
graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test")
# Create a new GraphRuntimeState with the ready_queue data
from core.workflow.entities import GraphRuntimeState
from core.workflow.graph_engine.ready_queue import ReadyQueueState
# Convert ready_queue_data to ReadyQueueState
ready_queue_state = ReadyQueueState(**ready_queue_data)
graph_runtime_state = GraphRuntimeState(
variable_pool=original_runtime_state.variable_pool,
start_at=original_runtime_state.start_at,
ready_queue=ready_queue_state,
)
# Update all nodes to use the new GraphRuntimeState
for node in graph.nodes.values():
node.graph_runtime_state = graph_runtime_state
# Create GraphEngine
command_channel = InMemoryChannel()
engine = GraphEngine(
tenant_id="test-tenant",
app_id="test-app",
workflow_id="test-workflow",
user_id="test-user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
graph=graph,
graph_config={},
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
# Verify that the ready_queue was loaded from GraphRuntimeState
assert engine._ready_queue.qsize() == 3
# Verify the initial state matches what was provided
initial_queue_state = engine.graph_runtime_state.ready_queue
assert initial_queue_state["type"] == "InMemoryReadyQueue"
assert initial_queue_state["version"] == "1.0"
assert len(initial_queue_state["items"]) == 3
assert initial_queue_state["items"] == ["node1", "node2", "node3"]