diff --git a/api/core/workflow/graph_engine/error_handling/error_handler.py b/api/core/workflow/graph_engine/error_handling/error_handler.py index 7f6abb146c..d99115b75b 100644 --- a/api/core/workflow/graph_engine/error_handling/error_handler.py +++ b/api/core/workflow/graph_engine/error_handling/error_handler.py @@ -2,7 +2,7 @@ Main error handler that coordinates error strategies. """ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, final from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum from core.workflow.graph import Graph @@ -17,6 +17,7 @@ if TYPE_CHECKING: from ..domain import GraphExecution +@final class ErrorHandler: """ Coordinates error handling strategies for node failures. @@ -43,7 +44,7 @@ class ErrorHandler: self.fail_branch_strategy = FailBranchStrategy() self.default_value_strategy = DefaultValueStrategy() - def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]: + def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: """ Handle a node failure event. diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index db3137e99a..9124fb3a45 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -3,7 +3,7 @@ Event handler implementations for different event types. """ import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType @@ -52,12 +52,12 @@ class EventHandlerRegistry: graph_runtime_state: GraphRuntimeState, graph_execution: GraphExecution, response_coordinator: ResponseStreamCoordinator, - event_collector: Optional["EventCollector"] = None, - branch_handler: Optional["BranchHandler"] = None, - edge_processor: Optional["EdgeProcessor"] = None, - node_state_manager: Optional["NodeStateManager"] = None, - execution_tracker: Optional["ExecutionTracker"] = None, - error_handler: Optional["ErrorHandler"] = None, + event_collector: "EventCollector", + branch_handler: "BranchHandler", + edge_processor: "EdgeProcessor", + node_state_manager: "NodeStateManager", + execution_tracker: "ExecutionTracker", + error_handler: "ErrorHandler", ) -> None: """ Initialize the event handler registry. @@ -67,12 +67,12 @@ class EventHandlerRegistry: graph_runtime_state: Runtime state with variable pool graph_execution: Graph execution aggregate response_coordinator: Response stream coordinator - event_collector: Optional event collector for collecting events - branch_handler: Optional branch handler for branch node processing - edge_processor: Optional edge processor for edge traversal - node_state_manager: Optional node state manager - execution_tracker: Optional execution tracker - error_handler: Optional error handler + event_collector: Event collector for collecting events + branch_handler: Branch handler for branch node processing + edge_processor: Edge processor for edge traversal + node_state_manager: Node state manager + execution_tracker: Execution tracker + error_handler: Error handler """ self.graph = graph self.graph_runtime_state = graph_runtime_state @@ -93,9 +93,8 @@ class EventHandlerRegistry: event: The event to handle """ # Events in loops or iterations are always collected - if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id): - if self.event_collector: - self.event_collector.collect(event) + if event.in_loop_id or event.in_iteration_id: + self.event_collector.collect(event) return # Handle specific event types @@ -125,12 +124,10 @@ class EventHandlerRegistry: ), ): # Iteration and loop events are collected directly - if self.event_collector: - self.event_collector.collect(event) + self.event_collector.collect(event) else: # Collect unhandled events - if self.event_collector: - self.event_collector.collect(event) + self.event_collector.collect(event) logger.warning("Unhandled event type: %s", type(event).__name__) def _handle_node_started(self, event: NodeRunStartedEvent) -> None: @@ -148,8 +145,7 @@ class EventHandlerRegistry: self.response_coordinator.track_node_execution(event.node_id, event.id) # Collect the event - if self.event_collector: - self.event_collector.collect(event) + self.event_collector.collect(event) def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ @@ -162,9 +158,8 @@ class EventHandlerRegistry: streaming_events = list(self.response_coordinator.intercept_event(event)) # Collect all events - if self.event_collector: - for stream_event in streaming_events: - self.event_collector.collect(stream_event) + for stream_event in streaming_events: + self.event_collector.collect(stream_event) def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: """ @@ -184,48 +179,37 @@ class EventHandlerRegistry: self._store_node_outputs(event) # Forward to response coordinator and emit streaming events - streaming_events = list(self.response_coordinator.intercept_event(event)) - if self.event_collector: - for stream_event in streaming_events: - self.event_collector.collect(stream_event) + streaming_events = self.response_coordinator.intercept_event(event) + for stream_event in streaming_events: + self.event_collector.collect(stream_event) # Process edges and get ready nodes node = self.graph.nodes[event.node_id] if node.execution_type == NodeExecutionType.BRANCH: - if self.branch_handler: - ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = [], [] + ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion( + event.node_id, event.node_run_result.edge_source_handle + ) else: - if self.edge_processor: - ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id) - else: - ready_nodes, edge_streaming_events = [], [] + ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id) # Collect streaming events from edge processing - if self.event_collector: - for edge_event in edge_streaming_events: - self.event_collector.collect(edge_event) + for edge_event in edge_streaming_events: + self.event_collector.collect(edge_event) # Enqueue ready nodes - if self.node_state_manager and self.execution_tracker: - for node_id in ready_nodes: - self.node_state_manager.enqueue_node(node_id) - self.execution_tracker.add(node_id) + for node_id in ready_nodes: + self.node_state_manager.enqueue_node(node_id) + self.execution_tracker.add(node_id) # Update execution tracking - if self.execution_tracker: - self.execution_tracker.remove(event.node_id) + self.execution_tracker.remove(event.node_id) # Handle response node outputs if node.execution_type == NodeExecutionType.RESPONSE: self._update_response_outputs(event) # Collect the event - if self.event_collector: - self.event_collector.collect(event) + self.event_collector.collect(event) def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: """ @@ -238,26 +222,16 @@ class EventHandlerRegistry: node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_failed(event.error) - if self.error_handler: - result = self.error_handler.handle_node_failure(event) + result = self.error_handler.handle_node_failure(event) - if result: - # Process the resulting event (retry, exception, etc.) - self.handle_event(result) - else: - # Abort execution - self.graph_execution.fail(RuntimeError(event.error)) - if self.event_collector: - self.event_collector.collect(event) - if self.execution_tracker: - self.execution_tracker.remove(event.node_id) + if result: + # Process the resulting event (retry, exception, etc.) + self.handle_event(result) else: - # Without error handler, just fail + # Abort execution self.graph_execution.fail(RuntimeError(event.error)) - if self.event_collector: - self.event_collector.collect(event) - if self.execution_tracker: - self.execution_tracker.remove(event.node_id) + self.event_collector.collect(event) + self.execution_tracker.remove(event.node_id) def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: """ diff --git a/api/core/workflow/graph_engine/graph_traversal/branch_handler.py b/api/core/workflow/graph_engine/graph_traversal/branch_handler.py index 685867a02d..deddd86eb8 100644 --- a/api/core/workflow/graph_engine/graph_traversal/branch_handler.py +++ b/api/core/workflow/graph_engine/graph_traversal/branch_handler.py @@ -2,9 +2,11 @@ Branch node handling for graph traversal. """ +from collections.abc import Sequence from typing import Optional from core.workflow.graph import Graph +from core.workflow.graph_events.node import NodeRunStreamChunkEvent from ..state_management import EdgeStateManager from .edge_processor import EdgeProcessor @@ -40,7 +42,9 @@ class BranchHandler: self.skip_propagator = skip_propagator self.edge_state_manager = edge_state_manager - def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]: + def handle_branch_completion( + self, node_id: str, selected_handle: Optional[str] + ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Handle completion of a branch node. diff --git a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py index 79a7952282..76e6d819bf 100644 --- a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py +++ b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py @@ -2,8 +2,11 @@ Edge processing logic for graph traversal. """ +from collections.abc import Sequence + from core.workflow.enums import NodeExecutionType from core.workflow.graph import Edge, Graph +from core.workflow.graph_events import NodeRunStreamChunkEvent from ..response_coordinator import ResponseStreamCoordinator from ..state_management import EdgeStateManager, NodeStateManager @@ -38,7 +41,9 @@ class EdgeProcessor: self.node_state_manager = node_state_manager self.response_coordinator = response_coordinator - def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]: + def process_node_success( + self, node_id: str, selected_handle: str | None = None + ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Process edges after a node succeeds. @@ -56,7 +61,7 @@ class EdgeProcessor: else: return self._process_non_branch_node_edges(node_id) - def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]: + def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Process edges for non-branch nodes (mark all as TAKEN). @@ -66,8 +71,8 @@ class EdgeProcessor: Returns: Tuple of (list of downstream nodes ready for execution, list of streaming events) """ - ready_nodes = [] - all_streaming_events = [] + ready_nodes: list[str] = [] + all_streaming_events: list[NodeRunStreamChunkEvent] = [] outgoing_edges = self.graph.get_outgoing_edges(node_id) for edge in outgoing_edges: @@ -77,7 +82,9 @@ class EdgeProcessor: return ready_nodes, all_streaming_events - def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]: + def _process_branch_node_edges( + self, node_id: str, selected_handle: str | None + ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Process edges for branch nodes. @@ -94,8 +101,8 @@ class EdgeProcessor: if not selected_handle: raise ValueError(f"Branch node {node_id} did not select any edge") - ready_nodes = [] - all_streaming_events = [] + ready_nodes: list[str] = [] + all_streaming_events: list[NodeRunStreamChunkEvent] = [] # Categorize edges selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) @@ -112,7 +119,7 @@ class EdgeProcessor: return ready_nodes, all_streaming_events - def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]: + def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Mark edge as taken and check downstream node. @@ -129,11 +136,11 @@ class EdgeProcessor: streaming_events = self.response_coordinator.on_edge_taken(edge.id) # Check if downstream node is ready - ready_nodes = [] + ready_nodes: list[str] = [] if self.node_state_manager.is_node_ready(edge.head): ready_nodes.append(edge.head) - return ready_nodes, list(streaming_events) + return ready_nodes, streaming_events def _process_skipped_edge(self, edge: Edge) -> None: """ diff --git a/api/core/workflow/graph_engine/graph_traversal/node_readiness.py b/api/core/workflow/graph_engine/graph_traversal/node_readiness.py index 93f9935a90..29e74e2f3f 100644 --- a/api/core/workflow/graph_engine/graph_traversal/node_readiness.py +++ b/api/core/workflow/graph_engine/graph_traversal/node_readiness.py @@ -71,7 +71,7 @@ class NodeReadinessChecker: Returns: List of node IDs that are now ready """ - ready_nodes = [] + ready_nodes: list[str] = [] outgoing_edges = self.graph.get_outgoing_edges(from_node_id) for edge in outgoing_edges: diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 7fc441f194..bee4651def 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -8,6 +8,8 @@ import threading import time from typing import TYPE_CHECKING, Optional +from core.workflow.graph_events.base import GraphNodeEventBase + from ..event_management import EventCollector, EventEmitter from .execution_coordinator import ExecutionCoordinator @@ -27,7 +29,7 @@ class Dispatcher: def __init__( self, - event_queue: queue.Queue, + event_queue: queue.Queue[GraphNodeEventBase], event_handler: "EventHandlerRegistry", event_collector: EventCollector, execution_coordinator: ExecutionCoordinator, diff --git a/api/core/workflow/graph_engine/state_management/edge_state_manager.py b/api/core/workflow/graph_engine/state_management/edge_state_manager.py index 9e238a6fdd..32d6ca5780 100644 --- a/api/core/workflow/graph_engine/state_management/edge_state_manager.py +++ b/api/core/workflow/graph_engine/state_management/edge_state_manager.py @@ -3,6 +3,7 @@ Manager for edge states during graph execution. """ import threading +from collections.abc import Sequence from typing import TypedDict from core.workflow.enums import NodeState @@ -87,7 +88,7 @@ class EdgeStateManager: with self._lock: return self.graph.edges[edge_id].state - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]: + def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: """ Categorize branch edges into selected and unselected. @@ -100,8 +101,8 @@ class EdgeStateManager: """ with self._lock: outgoing_edges = self.graph.get_outgoing_edges(node_id) - selected_edges = [] - unselected_edges = [] + selected_edges: list[Edge] = [] + unselected_edges: list[Edge] = [] for edge in outgoing_edges: if edge.source_handle == selected_handle: