chore(graph_events): Improve type hints

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 2025-08-28 04:41:48 +08:00
parent 7cbf4093f4
commit 1cd0792606
GPG Key ID: 6BA0D108DED011FF
7 changed files with 74 additions and 85 deletions
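The recurring pattern across these files is swapping typing.Optional for PEP 604 union syntax and making annotations explicit. A minimal, self-contained sketch of that headline change (the function name is illustrative, not from this repository):

```python
from typing import Optional


# Before: Optional[str] requires importing Optional from typing.
def find_user_old(user_id: str) -> Optional[str]:
    return None


# After: PEP 604 union syntax, no extra typing import on Python 3.10+.
def find_user_new(user_id: str) -> str | None:
    return None
```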

View File

@@ -2,7 +2,7 @@
Main error handler that coordinates error strategies.
"""
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
from core.workflow.graph import Graph
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..domain import GraphExecution
@final
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
@@ -43,7 +44,7 @@ class ErrorHandler:
self.fail_branch_strategy = FailBranchStrategy()
self.default_value_strategy = DefaultValueStrategy()
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
"""
Handle a node failure event.
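This hunk combines the two patterns applied to ErrorHandler: the @final decorator to forbid subclassing, and an X | None return type. A minimal sketch, with illustrative names rather than the project's actual classes:

```python
from typing import final


class Event:
    """Placeholder event type for this sketch."""


@final  # Type checkers flag any attempt to subclass ErrorHandlerSketch.
class ErrorHandlerSketch:
    def handle(self, error: Exception) -> Event | None:
        # Return a follow-up event, or None when execution should abort.
        if isinstance(error, ValueError):
            return Event()
        return None
```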

View File

@@ -3,7 +3,7 @@ Event handler implementations for different event types.
"""
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
@@ -52,12 +52,12 @@ class EventHandlerRegistry:
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: Optional["EventCollector"] = None,
branch_handler: Optional["BranchHandler"] = None,
edge_processor: Optional["EdgeProcessor"] = None,
node_state_manager: Optional["NodeStateManager"] = None,
execution_tracker: Optional["ExecutionTracker"] = None,
error_handler: Optional["ErrorHandler"] = None,
event_collector: "EventCollector",
branch_handler: "BranchHandler",
edge_processor: "EdgeProcessor",
node_state_manager: "NodeStateManager",
execution_tracker: "ExecutionTracker",
error_handler: "ErrorHandler",
) -> None:
"""
Initialize the event handler registry.
@@ -67,12 +67,12 @@ class EventHandlerRegistry:
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Optional event collector for collecting events
branch_handler: Optional branch handler for branch node processing
edge_processor: Optional edge processor for edge traversal
node_state_manager: Optional node state manager
execution_tracker: Optional execution tracker
error_handler: Optional error handler
event_collector: Event collector for collecting events
branch_handler: Branch handler for branch node processing
edge_processor: Edge processor for edge traversal
node_state_manager: Node state manager
execution_tracker: Execution tracker
error_handler: Error handler
"""
self.graph = graph
self.graph_runtime_state = graph_runtime_state
@@ -93,9 +93,8 @@
event: The event to handle
"""
# Events in loops or iterations are always collected
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
if self.event_collector:
self.event_collector.collect(event)
if event.in_loop_id or event.in_iteration_id:
self.event_collector.collect(event)
return
# Handle specific event types
@@ -125,12 +124,10 @@
),
):
# Iteration and loop events are collected directly
if self.event_collector:
self.event_collector.collect(event)
self.event_collector.collect(event)
else:
# Collect unhandled events
if self.event_collector:
self.event_collector.collect(event)
self.event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
@@ -148,8 +145,7 @@
self.response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
if self.event_collector:
self.event_collector.collect(event)
self.event_collector.collect(event)
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
@@ -162,9 +158,8 @@
streaming_events = list(self.response_coordinator.intercept_event(event))
# Collect all events
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
@@ -184,48 +179,37 @@
self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events
streaming_events = list(self.response_coordinator.intercept_event(event))
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
streaming_events = self.response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self.graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
if self.branch_handler:
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = [], []
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
if self.edge_processor:
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
else:
ready_nodes, edge_streaming_events = [], []
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
# Collect streaming events from edge processing
if self.event_collector:
for edge_event in edge_streaming_events:
self.event_collector.collect(edge_event)
for edge_event in edge_streaming_events:
self.event_collector.collect(edge_event)
# Enqueue ready nodes
if self.node_state_manager and self.execution_tracker:
for node_id in ready_nodes:
self.node_state_manager.enqueue_node(node_id)
self.execution_tracker.add(node_id)
for node_id in ready_nodes:
self.node_state_manager.enqueue_node(node_id)
self.execution_tracker.add(node_id)
# Update execution tracking
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
self.execution_tracker.remove(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
# Collect the event
if self.event_collector:
self.event_collector.collect(event)
self.event_collector.collect(event)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
@@ -238,26 +222,16 @@
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
if self.error_handler:
result = self.error_handler.handle_node_failure(event)
result = self.error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Without error handler, just fail
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
self.event_collector.collect(event)
self.execution_tracker.remove(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""

View File

@@ -2,9 +2,11 @@
Branch node handling for graph traversal.
"""
from collections.abc import Sequence
from typing import Optional
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
from ..state_management import EdgeStateManager
from .edge_processor import EdgeProcessor
@@ -40,7 +42,9 @@ class BranchHandler:
self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
def handle_branch_completion(
self, node_id: str, selected_handle: Optional[str]
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
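handle_branch_completion now advertises Sequence instead of bare list, so callers treat the result as read-only and the implementation is free to return either a list or a tuple. A rough sketch with hypothetical names and data, not the real BranchHandler:

```python
from collections.abc import Sequence


def handle_branch_completion_sketch(
    selected: str | None, handles: dict[str, list[str]]
) -> tuple[Sequence[str], Sequence[str]]:
    """Return (ready node ids, skipped node ids) without promising a concrete list type."""
    if selected is None:
        return (), ()
    ready = handles.get(selected, [])
    skipped = [n for key, nodes in handles.items() if key != selected for n in nodes]
    return ready, skipped
```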

View File

@@ -2,8 +2,11 @@
Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager
@@ -38,7 +41,9 @@ class EdgeProcessor:
self.node_state_manager = node_state_manager
self.response_coordinator = response_coordinator
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.
@@ -56,7 +61,7 @@
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
@@ -66,8 +71,8 @@
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes = []
all_streaming_events = []
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self.graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
@@ -77,7 +82,9 @@
return ready_nodes, all_streaming_events
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.
@@ -94,8 +101,8 @@
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes = []
all_streaming_events = []
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
@@ -112,7 +119,7 @@
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.
@@ -129,11 +136,11 @@
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes = []
ready_nodes: list[str] = []
if self.node_state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, list(streaming_events)
return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None:
"""

View File

@@ -71,7 +71,7 @@ class NodeReadinessChecker:
Returns:
List of node IDs that are now ready
"""
ready_nodes = []
ready_nodes: list[str] = []
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
for edge in outgoing_edges:
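The annotated empty accumulator (ready_nodes: list[str] = []) is what keeps strict checkers from inferring list[Any] for a bare []. A tiny self-contained illustration with a hypothetical function:

```python
def downstream_ready_nodes(edges: dict[str, bool]) -> list[str]:
    # Without the annotation, `[]` would be inferred as list[Any] (or flagged
    # as an unknown element type) under strict settings; the explicit type
    # keeps every append checked.
    ready_nodes: list[str] = []
    for node_id, is_ready in edges.items():
        if is_ready:
            ready_nodes.append(node_id)
    return ready_nodes
```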

View File

@@ -8,6 +8,8 @@ import threading
import time
from typing import TYPE_CHECKING, Optional
from core.workflow.graph_events.base import GraphNodeEventBase
from ..event_management import EventCollector, EventEmitter
from .execution_coordinator import ExecutionCoordinator
@@ -27,7 +29,7 @@ class Dispatcher:
def __init__(
self,
event_queue: queue.Queue,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
execution_coordinator: ExecutionCoordinator,
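Parameterizing queue.Queue tells the checker what flows through the dispatcher's queue, so get() results need no casts. A minimal sketch with a stand-in event class (queue.Queue[...] is subscriptable at runtime on Python 3.9+); names are illustrative:

```python
import queue
from dataclasses import dataclass


@dataclass
class NodeEvent:
    node_id: str


# The element type is part of the annotation, so event_queue.get() is typed as NodeEvent.
event_queue: queue.Queue[NodeEvent] = queue.Queue()
event_queue.put(NodeEvent(node_id="start"))
print(event_queue.get().node_id)
```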

View File

@@ -3,6 +3,7 @@ Manager for edge states during graph execution.
"""
import threading
from collections.abc import Sequence
from typing import TypedDict
from core.workflow.enums import NodeState
@@ -87,7 +88,7 @@ class EdgeStateManager:
with self._lock:
return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
@@ -100,8 +101,8 @@
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges = []
unselected_edges = []
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
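categorize_branch_edges keeps its single pass over the outgoing edges but now returns tuple[Sequence[Edge], Sequence[Edge]] and annotates both accumulators. A rough, lock-free sketch with a stand-in Edge dataclass rather than the real graph types:

```python
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class EdgeSketch:
    id: str
    source_handle: str


def categorize_branch_edges_sketch(
    outgoing: Sequence[EdgeSketch], selected_handle: str
) -> tuple[Sequence[EdgeSketch], Sequence[EdgeSketch]]:
    selected_edges: list[EdgeSketch] = []
    unselected_edges: list[EdgeSketch] = []
    for edge in outgoing:
        # An edge is "selected" when its handle matches the branch decision.
        if edge.source_handle == selected_handle:
            selected_edges.append(edge)
        else:
            unselected_edges.append(edge)
    return selected_edges, unselected_edges
```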