diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 63929381de..244f4a4d86 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -3,6 +3,7 @@ Event handler implementations for different event types. """ import logging +from functools import singledispatchmethod from typing import TYPE_CHECKING, final from core.workflow.entities import GraphRuntimeState @@ -81,7 +82,7 @@ class EventHandler: self._state_manager = state_manager self._error_handler = error_handler - def handle_event(self, event: GraphNodeEventBase) -> None: + def dispatch(self, event: GraphNodeEventBase) -> None: """ Handle any node event by dispatching to the appropriate handler. @@ -92,42 +93,27 @@ class EventHandler: if event.in_loop_id or event.in_iteration_id: self._event_collector.collect(event) return + return self._dispatch(event) - # Handle specific event types - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - elif isinstance(event, NodeRunStreamChunkEvent): - self._handle_stream_chunk(event) - elif isinstance(event, NodeRunSucceededEvent): - self._handle_node_succeeded(event) - elif isinstance(event, NodeRunFailedEvent): - self._handle_node_failed(event) - elif isinstance(event, NodeRunExceptionEvent): - self._handle_node_exception(event) - elif isinstance(event, NodeRunRetryEvent): - self._handle_node_retry(event) - elif isinstance( - event, - ( - NodeRunIterationStartedEvent, - NodeRunIterationNextEvent, - NodeRunIterationSucceededEvent, - NodeRunIterationFailedEvent, - NodeRunLoopStartedEvent, - NodeRunLoopNextEvent, - NodeRunLoopSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunAgentLogEvent, - ), - ): - # Iteration and loop events are collected directly - self._event_collector.collect(event) - else: - # Collect unhandled events - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) + @singledispatchmethod + def _dispatch(self, event: GraphNodeEventBase) -> None: + self._event_collector.collect(event) + logger.warning("Unhandled event type: %s", type(event).__name__) - def _handle_node_started(self, event: NodeRunStartedEvent) -> None: + @_dispatch.register(NodeRunIterationStartedEvent) + @_dispatch.register(NodeRunIterationNextEvent) + @_dispatch.register(NodeRunIterationSucceededEvent) + @_dispatch.register(NodeRunIterationFailedEvent) + @_dispatch.register(NodeRunLoopStartedEvent) + @_dispatch.register(NodeRunLoopNextEvent) + @_dispatch.register(NodeRunLoopSucceededEvent) + @_dispatch.register(NodeRunLoopFailedEvent) + @_dispatch.register(NodeRunAgentLogEvent) + def _(self, event: GraphNodeEventBase) -> None: + self._event_collector.collect(event) + + @_dispatch.register + def _(self, event: NodeRunStartedEvent) -> None: """ Handle node started event. @@ -144,7 +130,8 @@ class EventHandler: # Collect the event self._event_collector.collect(event) - def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunStreamChunkEvent) -> None: """ Handle stream chunk event with full processing. @@ -158,7 +145,8 @@ class EventHandler: for stream_event in streaming_events: self._event_collector.collect(stream_event) - def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunSucceededEvent) -> None: """ Handle node success by coordinating subsystems. @@ -208,7 +196,8 @@ class EventHandler: # Collect the event self._event_collector.collect(event) - def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunFailedEvent) -> None: """ Handle node failure using error handler. @@ -223,14 +212,15 @@ class EventHandler: if result: # Process the resulting event (retry, exception, etc.) - self.handle_event(result) + self.dispatch(result) else: # Abort execution self._graph_execution.fail(RuntimeError(event.error)) self._event_collector.collect(event) self._state_manager.finish_execution(event.node_id) - def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunExceptionEvent) -> None: """ Handle node exception event (fail-branch strategy). @@ -241,7 +231,8 @@ class EventHandler: node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() - def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunRetryEvent) -> None: """ Handle node retry event. diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index bb4720a684..a7229ce4e8 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -86,7 +86,7 @@ class Dispatcher: try: event = self._event_queue.get(timeout=0.1) # Route to the event handler - self._event_handler.handle_event(event) + self._event_handler.dispatch(event) self._event_queue.task_done() except queue.Empty: # Check if execution is complete