diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 9124fb3a45..842bd2635f 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -3,7 +3,7 @@ Event handler implementations for different event types. """ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, final from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType @@ -38,6 +38,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@final class EventHandlerRegistry: """ Registry of event handlers for different event types. @@ -74,16 +75,16 @@ class EventHandlerRegistry: execution_tracker: Execution tracker error_handler: Error handler """ - self.graph = graph - self.graph_runtime_state = graph_runtime_state - self.graph_execution = graph_execution - self.response_coordinator = response_coordinator - self.event_collector = event_collector - self.branch_handler = branch_handler - self.edge_processor = edge_processor - self.node_state_manager = node_state_manager - self.execution_tracker = execution_tracker - self.error_handler = error_handler + self._graph = graph + self._graph_runtime_state = graph_runtime_state + self._graph_execution = graph_execution + self._response_coordinator = response_coordinator + self._event_collector = event_collector + self._branch_handler = branch_handler + self._edge_processor = edge_processor + self._node_state_manager = node_state_manager + self._execution_tracker = execution_tracker + self._error_handler = error_handler def handle_event(self, event: GraphNodeEventBase) -> None: """ @@ -94,7 +95,7 @@ class EventHandlerRegistry: """ # Events in loops or iterations are always collected if event.in_loop_id or event.in_iteration_id: - self.event_collector.collect(event) + self._event_collector.collect(event) return # Handle specific event types @@ -124,10 +125,10 @@ class EventHandlerRegistry: ), ): # Iteration and loop events are collected directly - self.event_collector.collect(event) + self._event_collector.collect(event) else: # Collect unhandled events - self.event_collector.collect(event) + self._event_collector.collect(event) logger.warning("Unhandled event type: %s", type(event).__name__) def _handle_node_started(self, event: NodeRunStartedEvent) -> None: @@ -138,14 +139,14 @@ class EventHandlerRegistry: event: The node started event """ # Track execution in domain model - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_started(event.id) # Track in response coordinator for stream ordering - self.response_coordinator.track_node_execution(event.node_id, event.id) + self._response_coordinator.track_node_execution(event.node_id, event.id) # Collect the event - self.event_collector.collect(event) + self._event_collector.collect(event) def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ @@ -155,11 +156,11 @@ class EventHandlerRegistry: event: The stream chunk event """ # Process with response coordinator - streaming_events = list(self.response_coordinator.intercept_event(event)) + streaming_events = list(self._response_coordinator.intercept_event(event)) # Collect all events for stream_event in streaming_events: - self.event_collector.collect(stream_event) + self._event_collector.collect(stream_event) def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: """ @@ -172,44 +173,44 @@ class EventHandlerRegistry: event: The node succeeded event """ # Update domain model - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() # Store outputs in variable pool self._store_node_outputs(event) # Forward to response coordinator and emit streaming events - streaming_events = self.response_coordinator.intercept_event(event) + streaming_events = self._response_coordinator.intercept_event(event) for stream_event in streaming_events: - self.event_collector.collect(stream_event) + self._event_collector.collect(stream_event) # Process edges and get ready nodes - node = self.graph.nodes[event.node_id] + node = self._graph.nodes[event.node_id] if node.execution_type == NodeExecutionType.BRANCH: - ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion( + ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion( event.node_id, event.node_run_result.edge_source_handle ) else: - ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id) + ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) # Collect streaming events from edge processing for edge_event in edge_streaming_events: - self.event_collector.collect(edge_event) + self._event_collector.collect(edge_event) # Enqueue ready nodes for node_id in ready_nodes: - self.node_state_manager.enqueue_node(node_id) - self.execution_tracker.add(node_id) + self._node_state_manager.enqueue_node(node_id) + self._execution_tracker.add(node_id) # Update execution tracking - self.execution_tracker.remove(event.node_id) + self._execution_tracker.remove(event.node_id) # Handle response node outputs if node.execution_type == NodeExecutionType.RESPONSE: self._update_response_outputs(event) # Collect the event - self.event_collector.collect(event) + self._event_collector.collect(event) def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: """ @@ -219,19 +220,19 @@ class EventHandlerRegistry: event: The node failed event """ # Update domain model - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_failed(event.error) - result = self.error_handler.handle_node_failure(event) + result = self._error_handler.handle_node_failure(event) if result: # Process the resulting event (retry, exception, etc.) self.handle_event(result) else: # Abort execution - self.graph_execution.fail(RuntimeError(event.error)) - self.event_collector.collect(event) - self.execution_tracker.remove(event.node_id) + self._graph_execution.fail(RuntimeError(event.error)) + self._event_collector.collect(event) + self._execution_tracker.remove(event.node_id) def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: """ @@ -241,7 +242,7 @@ class EventHandlerRegistry: event: The node exception event """ # Node continues via fail-branch, so it's technically "succeeded" - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: @@ -251,7 +252,7 @@ class EventHandlerRegistry: Args: event: The node retry event """ - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.increment_retry() def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None: @@ -262,16 +263,16 @@ class EventHandlerRegistry: event: The node succeeded event containing outputs """ for variable_name, variable_value in event.node_run_result.outputs.items(): - self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) + self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None: """Update response outputs for response nodes.""" for key, value in event.node_run_result.outputs.items(): if key == "answer": - existing = self.graph_runtime_state.outputs.get("answer", "") + existing = self._graph_runtime_state.outputs.get("answer", "") if existing: - self.graph_runtime_state.outputs["answer"] = f"{existing}{value}" + self._graph_runtime_state.outputs["answer"] = f"{existing}{value}" else: - self.graph_runtime_state.outputs["answer"] = value + self._graph_runtime_state.outputs["answer"] = value else: - self.graph_runtime_state.outputs[key] = value + self._graph_runtime_state.outputs[key] = value diff --git a/api/core/workflow/graph_engine/output_registry/registry.py b/api/core/workflow/graph_engine/output_registry/registry.py index 42ccf51d62..6ffc6b178a 100644 --- a/api/core/workflow/graph_engine/output_registry/registry.py +++ b/api/core/workflow/graph_engine/output_registry/registry.py @@ -47,7 +47,7 @@ class OutputRegistry: with self._lock: self._scalars.add(selector, value) - def get_scalar(self, selector: Sequence[str]) -> "Segment" | None: + def get_scalar(self, selector: Sequence[str]) -> "Segment | None": """ Get a scalar value for the given selector. @@ -81,7 +81,7 @@ class OutputRegistry: except ValueError: raise ValueError(f"Stream {'.'.join(selector)} is already closed") - def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent" | None: + def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None": """ Pop the next unread NodeRunStreamChunkEvent from the stream. diff --git a/api/core/workflow/graph_engine/output_registry/stream.py b/api/core/workflow/graph_engine/output_registry/stream.py index 1e52d4efaa..e9a097d85f 100644 --- a/api/core/workflow/graph_engine/output_registry/stream.py +++ b/api/core/workflow/graph_engine/output_registry/stream.py @@ -41,7 +41,7 @@ class Stream: raise ValueError("Cannot append to a closed stream") self.events.append(event) - def pop_next(self) -> "NodeRunStreamChunkEvent" | None: + def pop_next(self) -> "NodeRunStreamChunkEvent | None": """ Pop the next unread NodeRunStreamChunkEvent from the stream.