diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 61a3e1baca..bf76ae8178 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ if isinstance(session_factory, Engine): session_factory = sessionmaker(session_factory) + super().__init__() self._session_maker = session_factory self._state_owner_user_id = state_owner_user_id self._generate_entity = generate_entity @@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer): if not isinstance(event, GraphRunPausedEvent): return - assert self.graph_runtime_state is not None - entity_wrapper: _GenerateEntityUnion if isinstance(self._generate_entity, WorkflowAppGenerateEntity): entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity) diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index fe1a46a945..225b758fcb 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer): trigger_log_id: str, session_maker: sessionmaker[Session], ): + super().__init__() self.trigger_log_id = trigger_log_id self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity @@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer): elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds() # Extract relevant data from result - if not self.graph_runtime_state: - logger.exception("Graph runtime state is not set") - return - outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index 72f5dbe1e2..9a39f976a6 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO")) engine.layer(ExecutionLimitsLayer(max_nodes=100)) ``` +`engine.layer()` binds the read-only runtime state before execution, so layer hooks +can assume `graph_runtime_state` is available. + ### Event-Driven Architecture All node executions emit events for monitoring and integration: diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2e8b8f345f..e4816afaf8 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -212,9 +212,16 @@ class GraphEngine: if id(node.graph_runtime_state) != expected_state_id: raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") + def _bind_layer_context( + self, + layer: GraphEngineLayer, + ) -> None: + layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) + def layer(self, layer: GraphEngineLayer) -> "GraphEngine": """Add a layer for extending functionality.""" self._layers.append(layer) + self._bind_layer_context(layer) return self def run(self) -> Generator[GraphEngineEvent, None, None]: @@ -301,14 +308,7 @@ class GraphEngine: def _initialize_layers(self) -> None: """Initialize layers with context.""" self._event_manager.set_layers(self._layers) - # Create a read-only wrapper for the runtime state - read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state) for layer in self._layers: - try: - layer.initialize(read_only_state, self._command_channel) - except Exception as e: - logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) - try: layer.on_graph_start() except Exception as e: diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md index 17845ee1f0..b0f295037c 100644 --- a/api/core/workflow/graph_engine/layers/README.md +++ b/api/core/workflow/graph_engine/layers/README.md @@ -8,7 +8,7 @@ Pluggable middleware for engine extensions. Abstract base class for layers. -- `initialize()` - Receive runtime context +- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) - `on_graph_start()` - Execution start hook - `on_event()` - Process all events - `on_graph_end()` - Execution end hook @@ -34,6 +34,9 @@ engine.layer(debug_layer) engine.run() ``` +`engine.layer()` binds the read-only runtime state before execution, so +`graph_runtime_state` is always available inside layer hooks. + ## Custom Layers ```python diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 780f92a0f4..89293b9b30 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState +class GraphEngineLayerNotInitializedError(Exception): + """Raised when a layer's runtime state is accessed before initialization.""" + + def __init__(self, layer_name: str | None = None) -> None: + name = layer_name or "GraphEngineLayer" + super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") + + class GraphEngineLayer(ABC): """ Abstract base class for GraphEngine layers. @@ -28,22 +36,27 @@ class GraphEngineLayer(ABC): def __init__(self) -> None: """Initialize the layer. Subclasses can override with custom parameters.""" - self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None + self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None self.command_channel: CommandChannel | None = None + @property + def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: + if self._graph_runtime_state is None: + raise GraphEngineLayerNotInitializedError(type(self).__name__) + return self._graph_runtime_state + def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: """ Initialize the layer with engine dependencies. - Called by GraphEngine before execution starts to inject the read-only runtime state - and command channel. This allows layers to observe engine context and send - commands, but prevents direct state modification. - + Called by GraphEngine to inject the read-only runtime state and command channel. + This is invoked when the layer is registered with a `GraphEngine` instance. + Implementations should be idempotent. Args: graph_runtime_state: Read-only view of the runtime state command_channel: Channel for sending commands to the engine """ - self.graph_runtime_state = graph_runtime_state + self._graph_runtime_state = graph_runtime_state self.command_channel = command_channel @abstractmethod diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index 034ebcf54f..e0402cd09c 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info("=" * 80) self.logger.info("🚀 GRAPH EXECUTION STARTED") self.logger.info("=" * 80) - - if self.graph_runtime_state: - # Log initial state - self.logger.info("Initial State:") + # Log initial state + self.logger.info("Initial State:") @override def on_event(self, event: GraphEngineEvent) -> None: @@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info(" Node retries: %s", self.retry_count) # Log final state if available - if self.graph_runtime_state and self.include_outputs: - if self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) + if self.include_outputs and self.graph_runtime_state.outputs: + self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py index b70f36ec9e..e81df4f3b7 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer): if update_finished: execution.finished_at = naive_utc_now() runtime_state = self.graph_runtime_state - if runtime_state is None: - return execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs @@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state - if runtime_state is None: - return {} return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 72469ad646..dcf31aeca7 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.enums import WorkflowExecutionStatus from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import GraphRunPausedEvent from core.workflow.runtime.graph_runtime_state import GraphRuntimeState from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState @@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers: """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() - # Don't initialize - graph_runtime_state should not be set + # Don't initialize - graph_runtime_state should be uninitialized event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) - # Act & Assert - Should raise AttributeError - with pytest.raises(AttributeError): + # Act & Assert - Should raise GraphEngineLayerNotInitializedError + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 534420f21e..5a5386ee57 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -15,6 +15,7 @@ from core.app.layers.pause_state_persist_layer import ( from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, @@ -209,8 +210,9 @@ class TestPauseStatePersistenceLayer: assert layer._session_maker is session_factory assert layer._state_owner_user_id == state_owner_user_id - assert not hasattr(layer, "graph_runtime_state") - assert not hasattr(layer, "command_channel") + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + assert layer.command_channel is None def test_initialize_sets_dependencies(self): session_factory = Mock(name="session_factory") @@ -295,7 +297,7 @@ class TestPauseStatePersistenceLayer: mock_factory.assert_not_called() mock_repo.create_workflow_pause.assert_not_called() - def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): + def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self): session_factory = Mock(name="session_factory") layer = PauseStatePersistenceLayer( session_factory=session_factory, @@ -305,7 +307,7 @@ class TestPauseStatePersistenceLayer: event = TestDataFactory.create_graph_run_paused_event() - with pytest.raises(AttributeError): + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py new file mode 100644 index 0000000000..d6ba61c50c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest + +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_engine.layers.base import ( + GraphEngineLayer, + GraphEngineLayerNotInitializedError, +) +from core.workflow.graph_events import GraphEngineEvent + +from ..test_table_runner import WorkflowRunner + + +class LayerForTest(GraphEngineLayer): + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + pass + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_layer_runtime_state_raises_when_uninitialized() -> None: + layer = LayerForTest() + + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + + +def test_layer_runtime_state_available_after_engine_layer() -> None: + runner = WorkflowRunner() + fixture_data = runner.load_fixture("simple_passthrough_workflow") + graph, graph_runtime_state = runner.create_graph_from_fixture( + fixture_data, + inputs={"query": "test layer state"}, + ) + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + layer = LayerForTest() + engine.layer(layer) + + outputs = layer.graph_runtime_state.outputs + ready_queue_size = layer.graph_runtime_state.ready_queue_size + + assert outputs == {} + assert isinstance(ready_queue_size, int) + assert ready_queue_size >= 0