chore(graph_events): Improve type hints

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-08-28 04:41:48 +08:00
parent 7cbf4093f4
commit 1cd0792606
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
7 changed files with 74 additions and 85 deletions

View File

@ -2,7 +2,7 @@
Main error handler that coordinates error strategies. Main error handler that coordinates error strategies.
""" """
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, final
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
from core.workflow.graph import Graph from core.workflow.graph import Graph
@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..domain import GraphExecution from ..domain import GraphExecution
@final
class ErrorHandler: class ErrorHandler:
""" """
Coordinates error handling strategies for node failures. Coordinates error handling strategies for node failures.
@ -43,7 +44,7 @@ class ErrorHandler:
self.fail_branch_strategy = FailBranchStrategy() self.fail_branch_strategy = FailBranchStrategy()
self.default_value_strategy = DefaultValueStrategy() self.default_value_strategy = DefaultValueStrategy()
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]: def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
""" """
Handle a node failure event. Handle a node failure event.

View File

@ -3,7 +3,7 @@ Event handler implementations for different event types.
""" """
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from core.workflow.entities import GraphRuntimeState from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType from core.workflow.enums import NodeExecutionType
@ -52,12 +52,12 @@ class EventHandlerRegistry:
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution, graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator, response_coordinator: ResponseStreamCoordinator,
event_collector: Optional["EventCollector"] = None, event_collector: "EventCollector",
branch_handler: Optional["BranchHandler"] = None, branch_handler: "BranchHandler",
edge_processor: Optional["EdgeProcessor"] = None, edge_processor: "EdgeProcessor",
node_state_manager: Optional["NodeStateManager"] = None, node_state_manager: "NodeStateManager",
execution_tracker: Optional["ExecutionTracker"] = None, execution_tracker: "ExecutionTracker",
error_handler: Optional["ErrorHandler"] = None, error_handler: "ErrorHandler",
) -> None: ) -> None:
""" """
Initialize the event handler registry. Initialize the event handler registry.
@ -67,12 +67,12 @@ class EventHandlerRegistry:
graph_runtime_state: Runtime state with variable pool graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator response_coordinator: Response stream coordinator
event_collector: Optional event collector for collecting events event_collector: Event collector for collecting events
branch_handler: Optional branch handler for branch node processing branch_handler: Branch handler for branch node processing
edge_processor: Optional edge processor for edge traversal edge_processor: Edge processor for edge traversal
node_state_manager: Optional node state manager node_state_manager: Node state manager
execution_tracker: Optional execution tracker execution_tracker: Execution tracker
error_handler: Optional error handler error_handler: Error handler
""" """
self.graph = graph self.graph = graph
self.graph_runtime_state = graph_runtime_state self.graph_runtime_state = graph_runtime_state
@ -93,9 +93,8 @@ class EventHandlerRegistry:
event: The event to handle event: The event to handle
""" """
# Events in loops or iterations are always collected # Events in loops or iterations are always collected
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id): if event.in_loop_id or event.in_iteration_id:
if self.event_collector: self.event_collector.collect(event)
self.event_collector.collect(event)
return return
# Handle specific event types # Handle specific event types
@ -125,12 +124,10 @@ class EventHandlerRegistry:
), ),
): ):
# Iteration and loop events are collected directly # Iteration and loop events are collected directly
if self.event_collector: self.event_collector.collect(event)
self.event_collector.collect(event)
else: else:
# Collect unhandled events # Collect unhandled events
if self.event_collector: self.event_collector.collect(event)
self.event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__) logger.warning("Unhandled event type: %s", type(event).__name__)
def _handle_node_started(self, event: NodeRunStartedEvent) -> None: def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
@ -148,8 +145,7 @@ class EventHandlerRegistry:
self.response_coordinator.track_node_execution(event.node_id, event.id) self.response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event # Collect the event
if self.event_collector: self.event_collector.collect(event)
self.event_collector.collect(event)
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
""" """
@ -162,9 +158,8 @@ class EventHandlerRegistry:
streaming_events = list(self.response_coordinator.intercept_event(event)) streaming_events = list(self.response_coordinator.intercept_event(event))
# Collect all events # Collect all events
if self.event_collector: for stream_event in streaming_events:
for stream_event in streaming_events: self.event_collector.collect(stream_event)
self.event_collector.collect(stream_event)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
""" """
@ -184,48 +179,37 @@ class EventHandlerRegistry:
self._store_node_outputs(event) self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events # Forward to response coordinator and emit streaming events
streaming_events = list(self.response_coordinator.intercept_event(event)) streaming_events = self.response_coordinator.intercept_event(event)
if self.event_collector: for stream_event in streaming_events:
for stream_event in streaming_events: self.event_collector.collect(stream_event)
self.event_collector.collect(stream_event)
# Process edges and get ready nodes # Process edges and get ready nodes
node = self.graph.nodes[event.node_id] node = self.graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH: if node.execution_type == NodeExecutionType.BRANCH:
if self.branch_handler: ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion( event.node_id, event.node_run_result.edge_source_handle
event.node_id, event.node_run_result.edge_source_handle )
)
else:
ready_nodes, edge_streaming_events = [], []
else: else:
if self.edge_processor: ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
else:
ready_nodes, edge_streaming_events = [], []
# Collect streaming events from edge processing # Collect streaming events from edge processing
if self.event_collector: for edge_event in edge_streaming_events:
for edge_event in edge_streaming_events: self.event_collector.collect(edge_event)
self.event_collector.collect(edge_event)
# Enqueue ready nodes # Enqueue ready nodes
if self.node_state_manager and self.execution_tracker: for node_id in ready_nodes:
for node_id in ready_nodes: self.node_state_manager.enqueue_node(node_id)
self.node_state_manager.enqueue_node(node_id) self.execution_tracker.add(node_id)
self.execution_tracker.add(node_id)
# Update execution tracking # Update execution tracking
if self.execution_tracker: self.execution_tracker.remove(event.node_id)
self.execution_tracker.remove(event.node_id)
# Handle response node outputs # Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE: if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event) self._update_response_outputs(event)
# Collect the event # Collect the event
if self.event_collector: self.event_collector.collect(event)
self.event_collector.collect(event)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
""" """
@ -238,26 +222,16 @@ class EventHandlerRegistry:
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error) node_execution.mark_failed(event.error)
if self.error_handler: result = self.error_handler.handle_node_failure(event)
result = self.error_handler.handle_node_failure(event)
if result: if result:
# Process the resulting event (retry, exception, etc.) # Process the resulting event (retry, exception, etc.)
self.handle_event(result) self.handle_event(result)
else:
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
else: else:
# Without error handler, just fail # Abort execution
self.graph_execution.fail(RuntimeError(event.error)) self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector: self.event_collector.collect(event)
self.event_collector.collect(event) self.execution_tracker.remove(event.node_id)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
""" """

View File

@ -2,9 +2,11 @@
Branch node handling for graph traversal. Branch node handling for graph traversal.
""" """
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
from ..state_management import EdgeStateManager from ..state_management import EdgeStateManager
from .edge_processor import EdgeProcessor from .edge_processor import EdgeProcessor
@ -40,7 +42,9 @@ class BranchHandler:
self.skip_propagator = skip_propagator self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager self.edge_state_manager = edge_state_manager
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]: def handle_branch_completion(
self, node_id: str, selected_handle: Optional[str]
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
""" """
Handle completion of a branch node. Handle completion of a branch node.

View File

@ -2,8 +2,11 @@
Edge processing logic for graph traversal. Edge processing logic for graph traversal.
""" """
from collections.abc import Sequence
from core.workflow.enums import NodeExecutionType from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..response_coordinator import ResponseStreamCoordinator from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager from ..state_management import EdgeStateManager, NodeStateManager
@ -38,7 +41,9 @@ class EdgeProcessor:
self.node_state_manager = node_state_manager self.node_state_manager = node_state_manager
self.response_coordinator = response_coordinator self.response_coordinator = response_coordinator
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]: def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
""" """
Process edges after a node succeeds. Process edges after a node succeeds.
@ -56,7 +61,7 @@ class EdgeProcessor:
else: else:
return self._process_non_branch_node_edges(node_id) return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]: def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
""" """
Process edges for non-branch nodes (mark all as TAKEN). Process edges for non-branch nodes (mark all as TAKEN).
@ -66,8 +71,8 @@ class EdgeProcessor:
Returns: Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events) Tuple of (list of downstream nodes ready for execution, list of streaming events)
""" """
ready_nodes = [] ready_nodes: list[str] = []
all_streaming_events = [] all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self.graph.get_outgoing_edges(node_id) outgoing_edges = self.graph.get_outgoing_edges(node_id)
for edge in outgoing_edges: for edge in outgoing_edges:
@ -77,7 +82,9 @@ class EdgeProcessor:
return ready_nodes, all_streaming_events return ready_nodes, all_streaming_events
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]: def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
""" """
Process edges for branch nodes. Process edges for branch nodes.
@ -94,8 +101,8 @@ class EdgeProcessor:
if not selected_handle: if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge") raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes = [] ready_nodes: list[str] = []
all_streaming_events = [] all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges # Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
@ -112,7 +119,7 @@ class EdgeProcessor:
return ready_nodes, all_streaming_events return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]: def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
""" """
Mark edge as taken and check downstream node. Mark edge as taken and check downstream node.
@ -129,11 +136,11 @@ class EdgeProcessor:
streaming_events = self.response_coordinator.on_edge_taken(edge.id) streaming_events = self.response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready # Check if downstream node is ready
ready_nodes = [] ready_nodes: list[str] = []
if self.node_state_manager.is_node_ready(edge.head): if self.node_state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head) ready_nodes.append(edge.head)
return ready_nodes, list(streaming_events) return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None: def _process_skipped_edge(self, edge: Edge) -> None:
""" """

View File

@ -71,7 +71,7 @@ class NodeReadinessChecker:
Returns: Returns:
List of node IDs that are now ready List of node IDs that are now ready
""" """
ready_nodes = [] ready_nodes: list[str] = []
outgoing_edges = self.graph.get_outgoing_edges(from_node_id) outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
for edge in outgoing_edges: for edge in outgoing_edges:

View File

@ -8,6 +8,8 @@ import threading
import time import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from core.workflow.graph_events.base import GraphNodeEventBase
from ..event_management import EventCollector, EventEmitter from ..event_management import EventCollector, EventEmitter
from .execution_coordinator import ExecutionCoordinator from .execution_coordinator import ExecutionCoordinator
@ -27,7 +29,7 @@ class Dispatcher:
def __init__( def __init__(
self, self,
event_queue: queue.Queue, event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandlerRegistry", event_handler: "EventHandlerRegistry",
event_collector: EventCollector, event_collector: EventCollector,
execution_coordinator: ExecutionCoordinator, execution_coordinator: ExecutionCoordinator,

View File

@ -3,6 +3,7 @@ Manager for edge states during graph execution.
""" """
import threading import threading
from collections.abc import Sequence
from typing import TypedDict from typing import TypedDict
from core.workflow.enums import NodeState from core.workflow.enums import NodeState
@ -87,7 +88,7 @@ class EdgeStateManager:
with self._lock: with self._lock:
return self.graph.edges[edge_id].state return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]: def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
""" """
Categorize branch edges into selected and unselected. Categorize branch edges into selected and unselected.
@ -100,8 +101,8 @@ class EdgeStateManager:
""" """
with self._lock: with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id) outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges = [] selected_edges: list[Edge] = []
unselected_edges = [] unselected_edges: list[Edge] = []
for edge in outgoing_edges: for edge in outgoing_edges:
if edge.source_handle == selected_handle: if edge.source_handle == selected_handle: