From b4ef1de30fcdbea9700e0bc8a2ae2474ea7ccbda Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 15 Sep 2025 03:05:10 +0800 Subject: [PATCH] feat(graph_engine): add ready_queue state persistence to GraphRuntimeState - Add ReadyQueueState TypedDict for type-safe queue serialization - Add ready_queue attribute to GraphRuntimeState for initializing with pre-existing queue state - Update GraphEngine to load ready_queue from GraphRuntimeState on initialization - Implement proper type hints using ReadyQueueState for better type safety - Add comprehensive tests for ready_queue loading functionality The ready_queue is read-only after initialization and allows resuming workflow execution with a pre-populated queue of nodes ready to execute. --- .../workflow/entities/graph_runtime_state.py | 16 +++- .../workflow/graph_engine/graph_engine.py | 10 +++ .../graph_engine/ready_queue/__init__.py | 4 +- .../graph_engine/ready_queue/in_memory.py | 20 ++--- .../graph_engine/ready_queue/protocol.py | 22 +++++- .../entities/test_graph_runtime_state.py | 28 +++++++ .../graph_engine/test_graph_engine.py | 75 +++++++++++++++++++ 7 files changed, 159 insertions(+), 16 deletions(-) diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index c06a62d1e7..c9ec426167 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, PrivateAttr @@ -7,6 +7,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage from .variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.graph_engine.ready_queue import ReadyQueueState + class GraphRuntimeState(BaseModel): # Private attributes to prevent direct modification @@ -16,6 +19,7 @@ class GraphRuntimeState(BaseModel): _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) _outputs: dict[str, Any] = PrivateAttr(default_factory=dict) _node_run_steps: int = PrivateAttr(default=0) + _ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict) def __init__( self, @@ -25,6 +29,7 @@ class GraphRuntimeState(BaseModel): llm_usage: LLMUsage | None = None, outputs: dict[str, Any] | None = None, node_run_steps: int = 0, + ready_queue: "ReadyQueueState | dict[str, object] | None" = None, **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" @@ -51,6 +56,10 @@ class GraphRuntimeState(BaseModel): raise ValueError("node_run_steps must be non-negative") self._node_run_steps = node_run_steps + if ready_queue is None: + ready_queue = {} + self._ready_queue = deepcopy(ready_queue) + @property def variable_pool(self) -> VariablePool: """Get the variable pool.""" @@ -133,3 +142,8 @@ class GraphRuntimeState(BaseModel): if tokens < 0: raise ValueError("tokens must be non-negative") self._total_tokens += tokens + + @property + def ready_queue(self) -> "ReadyQueueState | dict[str, object]": + """Get a copy of the ready queue state.""" + return deepcopy(self._ready_queue) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 6e58d19fd6..a7b582d803 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -106,6 +106,16 @@ class GraphEngine: # === Execution Queues === # Queue for nodes ready to execute self._ready_queue = InMemoryReadyQueue() + # Load ready queue state from GraphRuntimeState if not empty + ready_queue_state = self._graph_runtime_state.ready_queue + if ready_queue_state: + # Import ReadyQueueState here to avoid circular imports + from .ready_queue import ReadyQueueState + + # Ensure we have a ReadyQueueState object + if isinstance(ready_queue_state, dict): + ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore + self._ready_queue.loads(ready_queue_state) # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/core/workflow/graph_engine/ready_queue/__init__.py index 9b890880f5..448abda286 100644 --- a/api/core/workflow/graph_engine/ready_queue/__init__.py +++ b/api/core/workflow/graph_engine/ready_queue/__init__.py @@ -6,6 +6,6 @@ the queue of nodes ready for execution. """ from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue +from .protocol import ReadyQueue, ReadyQueueState -__all__ = ["InMemoryReadyQueue", "ReadyQueue"] +__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"] diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/core/workflow/graph_engine/ready_queue/in_memory.py index 90df9a0096..c3cfbb00ad 100644 --- a/api/core/workflow/graph_engine/ready_queue/in_memory.py +++ b/api/core/workflow/graph_engine/ready_queue/in_memory.py @@ -8,6 +8,8 @@ serialization capabilities for state storage. import queue from typing import final +from .protocol import ReadyQueueState + @final class InMemoryReadyQueue: @@ -80,12 +82,12 @@ class InMemoryReadyQueue: """ return self._queue.qsize() - def dumps(self) -> dict[str, object]: + def dumps(self) -> ReadyQueueState: """ Serialize the queue state for storage. Returns: - A dictionary containing the serialized queue state + A ReadyQueueState dictionary containing the serialized queue state """ # Extract all items from the queue without removing them items: list[str] = [] @@ -104,14 +106,14 @@ class InMemoryReadyQueue: for item in temp_items: self._queue.put(item) - return { - "type": "InMemoryReadyQueue", - "version": "1.0", - "items": items, - "maxsize": self._queue.maxsize, - } + return ReadyQueueState( + type="InMemoryReadyQueue", + version="1.0", + items=items, + maxsize=self._queue.maxsize, + ) - def loads(self, data: dict[str, object]) -> None: + def loads(self, data: ReadyQueueState) -> None: """ Restore the queue state from serialized data. diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/core/workflow/graph_engine/ready_queue/protocol.py index 0e457bcf05..d0f66d2955 100644 --- a/api/core/workflow/graph_engine/ready_queue/protocol.py +++ b/api/core/workflow/graph_engine/ready_queue/protocol.py @@ -5,7 +5,21 @@ This protocol defines the interface for managing the queue of nodes ready for execution, supporting both in-memory and persistent storage scenarios. """ -from typing import Protocol +from typing import Protocol, TypedDict + + +class ReadyQueueState(TypedDict): + """ + TypedDict for serialized ready queue state. + + This defines the structure of the dictionary returned by dumps() + and expected by loads() for ready queue serialization. + """ + + type: str # Queue implementation type (e.g., "InMemoryReadyQueue") + version: str # Serialization format version + items: list[str] # List of node IDs in the queue + maxsize: int # Maximum queue size (0 for unlimited) class ReadyQueue(Protocol): @@ -68,17 +82,17 @@ class ReadyQueue(Protocol): """ ... - def dumps(self) -> dict[str, object]: + def dumps(self) -> ReadyQueueState: """ Serialize the queue state for storage. Returns: - A dictionary containing the serialized queue state + A ReadyQueueState dictionary containing the serialized queue state that can be persisted and later restored """ ... - def loads(self, data: dict[str, object]) -> None: + def loads(self, data: ReadyQueueState) -> None: """ Restore the queue state from serialized data. diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 4d8483ce0d..067b8d8186 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,6 +4,7 @@ import pytest from core.workflow.entities.graph_runtime_state import GraphRuntimeState from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.ready_queue import ReadyQueueState class TestGraphRuntimeState: @@ -109,3 +110,30 @@ class TestGraphRuntimeState: # Original should remain unchanged assert state.get_output("nested")["level1"]["level2"]["value"] == "test" + + def test_ready_queue_property(self): + variable_pool = VariablePool() + + # Test default empty ready_queue + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + assert state.ready_queue == {} + + # Test initialization with ready_queue data as ReadyQueueState + queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0) + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data) + assert state.ready_queue == queue_data + + # Test with different ready_queue data at initialization + another_queue_data = ReadyQueueState( + type="InMemoryReadyQueue", + version="1.0", + items=["node3", "node4", "node5"], + maxsize=0, + ) + another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data) + assert another_state.ready_queue == another_queue_data + + # Test immutability - modifying retrieved queue doesn't affect internal state + retrieved_queue = state.ready_queue + retrieved_queue["items"].append("node6") + assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4aa33bde26..f03c19ab1c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -744,3 +744,78 @@ def test_event_sequence_validation_with_table_tests(): else: assert result.event_sequence_match is True assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" + + +def test_ready_queue_state_loading(): + """ + Test that the ready_queue state is properly loaded from GraphRuntimeState + during GraphEngine initialization. + """ + # Use the TableTestRunner to create a proper workflow instance + runner = TableTestRunner() + + # Create a simple workflow + test_case = WorkflowTestCase( + fixture_path="simple_passthrough_workflow", + inputs={"query": "test"}, + expected_outputs={"query": "test"}, + description="Test ready_queue loading", + ) + + # Load the workflow fixture + workflow_runner = runner.workflow_runner + fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow") + + # Create graph and runtime state with pre-populated ready_queue + ready_queue_data = { + "type": "InMemoryReadyQueue", + "version": "1.0", + "items": ["node1", "node2", "node3"], + "maxsize": 0, + } + + # We need to create the graph first, then create a new GraphRuntimeState with ready_queue + graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test") + + # Create a new GraphRuntimeState with the ready_queue data + from core.workflow.entities import GraphRuntimeState + from core.workflow.graph_engine.ready_queue import ReadyQueueState + + # Convert ready_queue_data to ReadyQueueState + ready_queue_state = ReadyQueueState(**ready_queue_data) + + graph_runtime_state = GraphRuntimeState( + variable_pool=original_runtime_state.variable_pool, + start_at=original_runtime_state.start_at, + ready_queue=ready_queue_state, + ) + + # Update all nodes to use the new GraphRuntimeState + for node in graph.nodes.values(): + node.graph_runtime_state = graph_runtime_state + + # Create GraphEngine + command_channel = InMemoryChannel() + engine = GraphEngine( + tenant_id="test-tenant", + app_id="test-app", + workflow_id="test-workflow", + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + graph=graph, + graph_config={}, + graph_runtime_state=graph_runtime_state, + command_channel=command_channel, + ) + + # Verify that the ready_queue was loaded from GraphRuntimeState + assert engine._ready_queue.qsize() == 3 + + # Verify the initial state matches what was provided + initial_queue_state = engine.graph_runtime_state.ready_queue + assert initial_queue_state["type"] == "InMemoryReadyQueue" + assert initial_queue_state["version"] == "1.0" + assert len(initial_queue_state["items"]) == 3 + assert initial_queue_state["items"] == ["node1", "node2", "node3"]