From da87fce751f77b12532135842a3d1063bb70d293 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Sep 2025 03:00:15 +0800 Subject: [PATCH] feat(graph_engine): dump and load ready queue --- api/.importlinter | 1 - .../workflow/entities/graph_runtime_state.py | 29 +++---- .../workflow/graph_engine/graph_engine.py | 20 ++--- .../graph_engine/ready_queue/__init__.py | 3 +- .../graph_engine/ready_queue/factory.py | 35 +++++++++ .../graph_engine/ready_queue/in_memory.py | 34 ++++----- .../graph_engine/ready_queue/protocol.py | 30 ++++---- .../graph_engine/response_coordinator/path.py | 2 +- .../entities/test_graph_runtime_state.py | 42 ----------- .../graph_engine/test_graph_engine.py | 75 ------------------- 10 files changed, 89 insertions(+), 182 deletions(-) create mode 100644 api/core/workflow/graph_engine/ready_queue/factory.py diff --git a/api/.importlinter b/api/.importlinter index c5c4126330..98fe5f50bb 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -31,7 +31,6 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels - core.workflow.entities.graph_runtime_state -> core.workflow.graph_engine.ready_queue [importlinter:contract:rsc] name = RSC diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index aefdde5fc7..2b29a36d82 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import TYPE_CHECKING, Any from pydantic import BaseModel, PrivateAttr @@ -7,9 +6,6 @@ from core.model_runtime.entities.llm_entities import LLMUsage from .variable_pool import VariablePool -if TYPE_CHECKING: - from core.workflow.graph_engine.ready_queue import ReadyQueueState - class GraphRuntimeState(BaseModel): # Private attributes to prevent direct modification @@ -19,17 +15,18 @@ class GraphRuntimeState(BaseModel): _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object]) _node_run_steps: int = PrivateAttr(default=0) - _ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict) + _ready_queue_json: str = PrivateAttr() def __init__( self, + *, variable_pool: VariablePool, start_at: float, total_tokens: int = 0, llm_usage: LLMUsage | None = None, - outputs: dict[str, Any] | None = None, + outputs: dict[str, object] | None = None, node_run_steps: int = 0, - ready_queue: "ReadyQueueState | dict[str, object] | None" = None, + ready_queue_json: str = "", **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" @@ -56,9 +53,7 @@ class GraphRuntimeState(BaseModel): raise ValueError("node_run_steps must be non-negative") self._node_run_steps = node_run_steps - if ready_queue is None: - ready_queue = {} - self._ready_queue = deepcopy(ready_queue) + self._ready_queue_json = ready_queue_json @property def variable_pool(self) -> VariablePool: @@ -99,24 +94,24 @@ class GraphRuntimeState(BaseModel): self._llm_usage = value.model_copy() @property - def outputs(self) -> dict[str, Any]: + def outputs(self) -> dict[str, object]: """Get a copy of the outputs dictionary.""" return deepcopy(self._outputs) @outputs.setter - def outputs(self, value: dict[str, Any]) -> None: + def outputs(self, value: dict[str, object]) -> None: """Set the outputs dictionary.""" self._outputs = deepcopy(value) - def set_output(self, key: str, value: Any) -> None: + def set_output(self, key: str, value: object) -> None: """Set a single output value.""" self._outputs[key] = deepcopy(value) - def get_output(self, key: str, default: Any = None) -> Any: + def get_output(self, key: str, default: object = None) -> object: """Get a single output value.""" return deepcopy(self._outputs.get(key, default)) - def update_outputs(self, updates: dict[str, Any]) -> None: + def update_outputs(self, updates: dict[str, object]) -> None: """Update multiple output values.""" for key, value in updates.items(): self._outputs[key] = deepcopy(value) @@ -144,6 +139,6 @@ class GraphRuntimeState(BaseModel): self._total_tokens += tokens @property - def ready_queue(self) -> "ReadyQueueState | dict[str, object]": + def ready_queue_json(self) -> str: """Get a copy of the ready queue state.""" - return deepcopy(self._ready_queue) + return self._ready_queue_json diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a7b582d803..dc85619421 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -18,6 +18,7 @@ from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper +from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue from core.workflow.graph_events import ( GraphEngineEvent, GraphNodeEventBase, @@ -38,7 +39,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import InMemoryReadyQueue +from .ready_queue import ReadyQueueState, create_ready_queue_from_state from .response_coordinator import ResponseStreamCoordinator from .worker_management import WorkerPool @@ -104,18 +105,13 @@ class GraphEngine: self._scale_down_idle_time = scale_down_idle_time # === Execution Queues === - # Queue for nodes ready to execute - self._ready_queue = InMemoryReadyQueue() - # Load ready queue state from GraphRuntimeState if not empty - ready_queue_state = self._graph_runtime_state.ready_queue - if ready_queue_state: - # Import ReadyQueueState here to avoid circular imports - from .ready_queue import ReadyQueueState + # Create ready queue from saved state or initialize new one + if self._graph_runtime_state.ready_queue_json == "": + self._ready_queue = InMemoryReadyQueue() + else: + ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json) + self._ready_queue = create_ready_queue_from_state(ready_queue_state) - # Ensure we have a ReadyQueueState object - if isinstance(ready_queue_state, dict): - ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore - self._ready_queue.loads(ready_queue_state) # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/core/workflow/graph_engine/ready_queue/__init__.py index 448abda286..acba0e961c 100644 --- a/api/core/workflow/graph_engine/ready_queue/__init__.py +++ b/api/core/workflow/graph_engine/ready_queue/__init__.py @@ -5,7 +5,8 @@ This package contains the protocol and implementations for managing the queue of nodes ready for execution. """ +from .factory import create_ready_queue_from_state from .in_memory import InMemoryReadyQueue from .protocol import ReadyQueue, ReadyQueueState -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"] +__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py new file mode 100644 index 0000000000..1144e1de69 --- /dev/null +++ b/api/core/workflow/graph_engine/ready_queue/factory.py @@ -0,0 +1,35 @@ +""" +Factory for creating ReadyQueue instances from serialized state. +""" + +from typing import TYPE_CHECKING + +from .in_memory import InMemoryReadyQueue +from .protocol import ReadyQueueState + +if TYPE_CHECKING: + from .protocol import ReadyQueue + + +def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue": + """ + Create a ReadyQueue instance from a serialized state. + + Args: + state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue + + Returns: + A ReadyQueue instance initialized with the given state + + Raises: + ValueError: If the queue type is unknown or version is unsupported + """ + if state.type == "InMemoryReadyQueue": + if state.version != "1.0": + raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") + queue = InMemoryReadyQueue() + # Always pass as JSON string to loads() + queue.loads(state.model_dump_json()) + return queue + else: + raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/core/workflow/graph_engine/ready_queue/in_memory.py index c3cfbb00ad..e01ecdc160 100644 --- a/api/core/workflow/graph_engine/ready_queue/in_memory.py +++ b/api/core/workflow/graph_engine/ready_queue/in_memory.py @@ -82,12 +82,12 @@ class InMemoryReadyQueue: """ return self._queue.qsize() - def dumps(self) -> ReadyQueueState: + def dumps(self) -> str: """ - Serialize the queue state for storage. + Serialize the queue state to a JSON string for storage. Returns: - A ReadyQueueState dictionary containing the serialized queue state + A JSON string containing the serialized queue state """ # Extract all items from the queue without removing them items: list[str] = [] @@ -106,25 +106,27 @@ class InMemoryReadyQueue: for item in temp_items: self._queue.put(item) - return ReadyQueueState( + state = ReadyQueueState( type="InMemoryReadyQueue", version="1.0", items=items, - maxsize=self._queue.maxsize, ) + return state.model_dump_json() - def loads(self, data: ReadyQueueState) -> None: + def loads(self, data: str) -> None: """ - Restore the queue state from serialized data. + Restore the queue state from a JSON string. Args: - data: The serialized queue state to restore + data: The JSON string containing the serialized queue state to restore """ - if data.get("type") != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {data.get('type')}") + state = ReadyQueueState.model_validate_json(data) - if data.get("version") != "1.0": - raise ValueError(f"Unsupported version: {data.get('version')}") + if state.type != "InMemoryReadyQueue": + raise ValueError(f"Invalid serialized data type: {state.type}") + + if state.version != "1.0": + raise ValueError(f"Unsupported version: {state.version}") # Clear the current queue while not self._queue.empty(): @@ -134,11 +136,5 @@ class InMemoryReadyQueue: break # Restore items - items = data.get("items", []) - if not isinstance(items, list): - raise ValueError("Invalid items data: expected list") - - for item in items: - if not isinstance(item, str): - raise ValueError(f"Invalid item type: expected str, got {type(item).__name__}") + for item in state.items: self._queue.put(item) diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/core/workflow/graph_engine/ready_queue/protocol.py index d0f66d2955..97d3ea6dd2 100644 --- a/api/core/workflow/graph_engine/ready_queue/protocol.py +++ b/api/core/workflow/graph_engine/ready_queue/protocol.py @@ -5,21 +5,23 @@ This protocol defines the interface for managing the queue of nodes ready for execution, supporting both in-memory and persistent storage scenarios. """ -from typing import Protocol, TypedDict +from collections.abc import Sequence +from typing import Protocol + +from pydantic import BaseModel, Field -class ReadyQueueState(TypedDict): +class ReadyQueueState(BaseModel): """ - TypedDict for serialized ready queue state. + Pydantic model for serialized ready queue state. - This defines the structure of the dictionary returned by dumps() + This defines the structure of the data returned by dumps() and expected by loads() for ready queue serialization. """ - type: str # Queue implementation type (e.g., "InMemoryReadyQueue") - version: str # Serialization format version - items: list[str] # List of node IDs in the queue - maxsize: int # Maximum queue size (0 for unlimited) + type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") + version: str = Field(description="Serialization format version") + items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") class ReadyQueue(Protocol): @@ -82,21 +84,21 @@ class ReadyQueue(Protocol): """ ... - def dumps(self) -> ReadyQueueState: + def dumps(self) -> str: """ - Serialize the queue state for storage. + Serialize the queue state to a JSON string for storage. Returns: - A ReadyQueueState dictionary containing the serialized queue state + A JSON string containing the serialized queue state that can be persisted and later restored """ ... - def loads(self, data: ReadyQueueState) -> None: + def loads(self, data: str) -> None: """ - Restore the queue state from serialized data. + Restore the queue state from a JSON string. Args: - data: The serialized queue state to restore + data: The JSON string containing the serialized queue state to restore """ ... diff --git a/api/core/workflow/graph_engine/response_coordinator/path.py b/api/core/workflow/graph_engine/response_coordinator/path.py index d83dd5e77b..50f2f4eb21 100644 --- a/api/core/workflow/graph_engine/response_coordinator/path.py +++ b/api/core/workflow/graph_engine/response_coordinator/path.py @@ -19,7 +19,7 @@ class Path: Note: This is an internal class not exposed in the public API. """ - edges: list[EdgeID] = field(default_factory=list) + edges: list[EdgeID] = field(default_factory=list[EdgeID]) def contains_edge(self, edge_id: EdgeID) -> bool: """Check if this path contains the given edge.""" diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 067b8d8186..2614424dc7 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,7 +4,6 @@ import pytest from core.workflow.entities.graph_runtime_state import GraphRuntimeState from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.ready_queue import ReadyQueueState class TestGraphRuntimeState: @@ -96,44 +95,3 @@ class TestGraphRuntimeState: # Test add_tokens validation with pytest.raises(ValueError): state.add_tokens(-1) - - def test_deep_copy_for_nested_objects(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test deep copy for nested dict - nested_data = {"level1": {"level2": {"value": "test"}}} - state.set_output("nested", nested_data) - - retrieved = state.get_output("nested") - retrieved["level1"]["level2"]["value"] = "modified" - - # Original should remain unchanged - assert state.get_output("nested")["level1"]["level2"]["value"] == "test" - - def test_ready_queue_property(self): - variable_pool = VariablePool() - - # Test default empty ready_queue - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - assert state.ready_queue == {} - - # Test initialization with ready_queue data as ReadyQueueState - queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0) - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data) - assert state.ready_queue == queue_data - - # Test with different ready_queue data at initialization - another_queue_data = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=["node3", "node4", "node5"], - maxsize=0, - ) - another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data) - assert another_state.ready_queue == another_queue_data - - # Test immutability - modifying retrieved queue doesn't affect internal state - retrieved_queue = state.ready_queue - retrieved_queue["items"].append("node6") - assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index f03c19ab1c..4aa33bde26 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -744,78 +744,3 @@ def test_event_sequence_validation_with_table_tests(): else: assert result.event_sequence_match is True assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" - - -def test_ready_queue_state_loading(): - """ - Test that the ready_queue state is properly loaded from GraphRuntimeState - during GraphEngine initialization. - """ - # Use the TableTestRunner to create a proper workflow instance - runner = TableTestRunner() - - # Create a simple workflow - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test"}, - expected_outputs={"query": "test"}, - description="Test ready_queue loading", - ) - - # Load the workflow fixture - workflow_runner = runner.workflow_runner - fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow") - - # Create graph and runtime state with pre-populated ready_queue - ready_queue_data = { - "type": "InMemoryReadyQueue", - "version": "1.0", - "items": ["node1", "node2", "node3"], - "maxsize": 0, - } - - # We need to create the graph first, then create a new GraphRuntimeState with ready_queue - graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test") - - # Create a new GraphRuntimeState with the ready_queue data - from core.workflow.entities import GraphRuntimeState - from core.workflow.graph_engine.ready_queue import ReadyQueueState - - # Convert ready_queue_data to ReadyQueueState - ready_queue_state = ReadyQueueState(**ready_queue_data) - - graph_runtime_state = GraphRuntimeState( - variable_pool=original_runtime_state.variable_pool, - start_at=original_runtime_state.start_at, - ready_queue=ready_queue_state, - ) - - # Update all nodes to use the new GraphRuntimeState - for node in graph.nodes.values(): - node.graph_runtime_state = graph_runtime_state - - # Create GraphEngine - command_channel = InMemoryChannel() - engine = GraphEngine( - tenant_id="test-tenant", - app_id="test-app", - workflow_id="test-workflow", - user_id="test-user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - graph=graph, - graph_config={}, - graph_runtime_state=graph_runtime_state, - command_channel=command_channel, - ) - - # Verify that the ready_queue was loaded from GraphRuntimeState - assert engine._ready_queue.qsize() == 3 - - # Verify the initial state matches what was provided - initial_queue_state = engine.graph_runtime_state.ready_queue - assert initial_queue_state["type"] == "InMemoryReadyQueue" - assert initial_queue_state["version"] == "1.0" - assert len(initial_queue_state["items"]) == 3 - assert initial_queue_state["items"] == ["node1", "node2", "node3"]