fix: type hints

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-08-28 05:24:18 +08:00
parent 8aab7f49c3
commit e3a7b1f691
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
3 changed files with 46 additions and 45 deletions

View File

@ -3,7 +3,7 @@ Event handler implementations for different event types.
"""
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
@ -38,6 +38,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@final
class EventHandlerRegistry:
"""
Registry of event handlers for different event types.
@ -74,16 +75,16 @@ class EventHandlerRegistry:
execution_tracker: Execution tracker
error_handler: Error handler
"""
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.graph_execution = graph_execution
self.response_coordinator = response_coordinator
self.event_collector = event_collector
self.branch_handler = branch_handler
self.edge_processor = edge_processor
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.error_handler = error_handler
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._branch_handler = branch_handler
self._edge_processor = edge_processor
self._node_state_manager = node_state_manager
self._execution_tracker = execution_tracker
self._error_handler = error_handler
def handle_event(self, event: GraphNodeEventBase) -> None:
"""
@ -94,7 +95,7 @@ class EventHandlerRegistry:
"""
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
self.event_collector.collect(event)
self._event_collector.collect(event)
return
# Handle specific event types
@ -124,10 +125,10 @@ class EventHandlerRegistry:
),
):
# Iteration and loop events are collected directly
self.event_collector.collect(event)
self._event_collector.collect(event)
else:
# Collect unhandled events
self.event_collector.collect(event)
self._event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
@ -138,14 +139,14 @@ class EventHandlerRegistry:
event: The node started event
"""
# Track execution in domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self.response_coordinator.track_node_execution(event.node_id, event.id)
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self.event_collector.collect(event)
self._event_collector.collect(event)
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
@ -155,11 +156,11 @@ class EventHandlerRegistry:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self.response_coordinator.intercept_event(event))
streaming_events = list(self._response_coordinator.intercept_event(event))
# Collect all events
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
self._event_collector.collect(stream_event)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
@ -172,44 +173,44 @@ class EventHandlerRegistry:
event: The node succeeded event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events
streaming_events = self.response_coordinator.intercept_event(event)
streaming_events = self._response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
self._event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self.graph.nodes[event.node_id]
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
# Collect streaming events from edge processing
for edge_event in edge_streaming_events:
self.event_collector.collect(edge_event)
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self.node_state_manager.enqueue_node(node_id)
self.execution_tracker.add(node_id)
self._node_state_manager.enqueue_node(node_id)
self._execution_tracker.add(node_id)
# Update execution tracking
self.execution_tracker.remove(event.node_id)
self._execution_tracker.remove(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
# Collect the event
self.event_collector.collect(event)
self._event_collector.collect(event)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
@ -219,19 +220,19 @@ class EventHandlerRegistry:
event: The node failed event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
result = self.error_handler.handle_node_failure(event)
result = self._error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
self.event_collector.collect(event)
self.execution_tracker.remove(event.node_id)
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._execution_tracker.remove(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""
@ -241,7 +242,7 @@ class EventHandlerRegistry:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
@ -251,7 +252,7 @@ class EventHandlerRegistry:
Args:
event: The node retry event
"""
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
@ -262,16 +263,16 @@ class EventHandlerRegistry:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self.graph_runtime_state.outputs.get("answer", "")
existing = self._graph_runtime_state.outputs.get("answer", "")
if existing:
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
else:
self.graph_runtime_state.outputs["answer"] = value
self._graph_runtime_state.outputs["answer"] = value
else:
self.graph_runtime_state.outputs[key] = value
self._graph_runtime_state.outputs[key] = value

View File

@ -47,7 +47,7 @@ class OutputRegistry:
with self._lock:
self._scalars.add(selector, value)
def get_scalar(self, selector: Sequence[str]) -> "Segment" | None:
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
"""
Get a scalar value for the given selector.
@ -81,7 +81,7 @@ class OutputRegistry:
except ValueError:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent" | None:
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.

View File

@ -41,7 +41,7 @@ class Stream:
raise ValueError("Cannot append to a closed stream")
self.events.append(event)
def pop_next(self) -> "NodeRunStreamChunkEvent" | None:
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.