mirror of https://github.com/langgenius/dify.git
chore(graph_events): Improve type hints
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
7cbf4093f4
commit
1cd0792606
|
|
@ -2,7 +2,7 @@
|
|||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||
from ..domain import GraphExecution
|
||||
|
||||
|
||||
@final
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
|
|
@ -43,7 +44,7 @@ class ErrorHandler:
|
|||
self.fail_branch_strategy = FailBranchStrategy()
|
||||
self.default_value_strategy = DefaultValueStrategy()
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Event handler implementations for different event types.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
|
|
@ -52,12 +52,12 @@ class EventHandlerRegistry:
|
|||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: Optional["EventCollector"] = None,
|
||||
branch_handler: Optional["BranchHandler"] = None,
|
||||
edge_processor: Optional["EdgeProcessor"] = None,
|
||||
node_state_manager: Optional["NodeStateManager"] = None,
|
||||
execution_tracker: Optional["ExecutionTracker"] = None,
|
||||
error_handler: Optional["ErrorHandler"] = None,
|
||||
event_collector: "EventCollector",
|
||||
branch_handler: "BranchHandler",
|
||||
edge_processor: "EdgeProcessor",
|
||||
node_state_manager: "NodeStateManager",
|
||||
execution_tracker: "ExecutionTracker",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
|
|
@ -67,12 +67,12 @@ class EventHandlerRegistry:
|
|||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Optional event collector for collecting events
|
||||
branch_handler: Optional branch handler for branch node processing
|
||||
edge_processor: Optional edge processor for edge traversal
|
||||
node_state_manager: Optional node state manager
|
||||
execution_tracker: Optional execution tracker
|
||||
error_handler: Optional error handler
|
||||
event_collector: Event collector for collecting events
|
||||
branch_handler: Branch handler for branch node processing
|
||||
edge_processor: Edge processor for edge traversal
|
||||
node_state_manager: Node state manager
|
||||
execution_tracker: Execution tracker
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
|
@ -93,9 +93,8 @@ class EventHandlerRegistry:
|
|||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self.event_collector.collect(event)
|
||||
return
|
||||
|
||||
# Handle specific event types
|
||||
|
|
@ -125,12 +124,10 @@ class EventHandlerRegistry:
|
|||
),
|
||||
):
|
||||
# Iteration and loop events are collected directly
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self.event_collector.collect(event)
|
||||
else:
|
||||
# Collect unhandled events
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self.event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
|
|
@ -148,8 +145,7 @@ class EventHandlerRegistry:
|
|||
self.response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self.event_collector.collect(event)
|
||||
|
||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
|
|
@ -162,9 +158,8 @@ class EventHandlerRegistry:
|
|||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
|
|
@ -184,48 +179,37 @@ class EventHandlerRegistry:
|
|||
self._store_node_outputs(event)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
streaming_events = self.response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self.graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
if self.branch_handler:
|
||||
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = [], []
|
||||
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
if self.edge_processor:
|
||||
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = [], []
|
||||
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
if self.event_collector:
|
||||
for edge_event in edge_streaming_events:
|
||||
self.event_collector.collect(edge_event)
|
||||
for edge_event in edge_streaming_events:
|
||||
self.event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
if self.node_state_manager and self.execution_tracker:
|
||||
for node_id in ready_nodes:
|
||||
self.node_state_manager.enqueue_node(node_id)
|
||||
self.execution_tracker.add(node_id)
|
||||
for node_id in ready_nodes:
|
||||
self.node_state_manager.enqueue_node(node_id)
|
||||
self.execution_tracker.add(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self.event_collector.collect(event)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
|
|
@ -238,26 +222,16 @@ class EventHandlerRegistry:
|
|||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
|
||||
if self.error_handler:
|
||||
result = self.error_handler.handle_node_failure(event)
|
||||
result = self.error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Without error handler, just fail
|
||||
# Abort execution
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
self.event_collector.collect(event)
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@
|
|||
Branch node handling for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
|
||||
|
||||
from ..state_management import EdgeStateManager
|
||||
from .edge_processor import EdgeProcessor
|
||||
|
|
@ -40,7 +42,9 @@ class BranchHandler:
|
|||
self.skip_propagator = skip_propagator
|
||||
self.edge_state_manager = edge_state_manager
|
||||
|
||||
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: Optional[str]
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,11 @@
|
|||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
|
@ -38,7 +41,9 @@ class EdgeProcessor:
|
|||
self.node_state_manager = node_state_manager
|
||||
self.response_coordinator = response_coordinator
|
||||
|
||||
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
|
|
@ -56,7 +61,7 @@ class EdgeProcessor:
|
|||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
|
|
@ -66,8 +71,8 @@ class EdgeProcessor:
|
|||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes = []
|
||||
all_streaming_events = []
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
|
|
@ -77,7 +82,9 @@ class EdgeProcessor:
|
|||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
|
|
@ -94,8 +101,8 @@ class EdgeProcessor:
|
|||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes = []
|
||||
all_streaming_events = []
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
|
@ -112,7 +119,7 @@ class EdgeProcessor:
|
|||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
|
|
@ -129,11 +136,11 @@ class EdgeProcessor:
|
|||
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes = []
|
||||
ready_nodes: list[str] = []
|
||||
if self.node_state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, list(streaming_events)
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class NodeReadinessChecker:
|
|||
Returns:
|
||||
List of node IDs that are now ready
|
||||
"""
|
||||
ready_nodes = []
|
||||
ready_nodes: list[str] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import threading
|
|||
import time
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
|
||||
from ..event_management import EventCollector, EventEmitter
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
|
|
@ -27,7 +29,7 @@ class Dispatcher:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandlerRegistry",
|
||||
event_collector: EventCollector,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ Manager for edge states during graph execution.
|
|||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
|
@ -87,7 +88,7 @@ class EdgeStateManager:
|
|||
with self._lock:
|
||||
return self.graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
|
|
@ -100,8 +101,8 @@ class EdgeStateManager:
|
|||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
selected_edges = []
|
||||
unselected_edges = []
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
|
|
|
|||
Loading…
Reference in New Issue