mirror of https://github.com/langgenius/dify.git
fix: type hints
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
8aab7f49c3
commit
e3a7b1f691
|
|
@ -3,7 +3,7 @@ Event handler implementations for different event types.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
|
|
@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandlerRegistry:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
|
@ -74,16 +75,16 @@ class EventHandlerRegistry:
|
|||
execution_tracker: Execution tracker
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.graph_execution = graph_execution
|
||||
self.response_coordinator = response_coordinator
|
||||
self.event_collector = event_collector
|
||||
self.branch_handler = branch_handler
|
||||
self.edge_processor = edge_processor
|
||||
self.node_state_manager = node_state_manager
|
||||
self.execution_tracker = execution_tracker
|
||||
self.error_handler = error_handler
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._branch_handler = branch_handler
|
||||
self._edge_processor = edge_processor
|
||||
self._node_state_manager = node_state_manager
|
||||
self._execution_tracker = execution_tracker
|
||||
self._error_handler = error_handler
|
||||
|
||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
|
|
@ -94,7 +95,7 @@ class EventHandlerRegistry:
|
|||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
|
||||
# Handle specific event types
|
||||
|
|
@ -124,10 +125,10 @@ class EventHandlerRegistry:
|
|||
),
|
||||
):
|
||||
# Iteration and loop events are collected directly
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
else:
|
||||
# Collect unhandled events
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
|
|
@ -138,14 +139,14 @@ class EventHandlerRegistry:
|
|||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self.response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
|
|
@ -155,11 +156,11 @@ class EventHandlerRegistry:
|
|||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
streaming_events = list(self._response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
|
|
@ -172,44 +173,44 @@ class EventHandlerRegistry:
|
|||
event: The node succeeded event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self.response_coordinator.intercept_event(event)
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self.graph.nodes[event.node_id]
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
|
||||
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
for edge_event in edge_streaming_events:
|
||||
self.event_collector.collect(edge_event)
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self.node_state_manager.enqueue_node(node_id)
|
||||
self.execution_tracker.add(node_id)
|
||||
self._node_state_manager.enqueue_node(node_id)
|
||||
self._execution_tracker.add(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
|
||||
# Collect the event
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
|
|
@ -219,19 +220,19 @@ class EventHandlerRegistry:
|
|||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
|
||||
result = self.error_handler.handle_node_failure(event)
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
self.event_collector.collect(event)
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
|
|
@ -241,7 +242,7 @@ class EventHandlerRegistry:
|
|||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
|
|
@ -251,7 +252,7 @@ class EventHandlerRegistry:
|
|||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
|
|
@ -262,16 +263,16 @@ class EventHandlerRegistry:
|
|||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self.graph_runtime_state.outputs.get("answer", "")
|
||||
existing = self._graph_runtime_state.outputs.get("answer", "")
|
||||
if existing:
|
||||
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
else:
|
||||
self.graph_runtime_state.outputs["answer"] = value
|
||||
self._graph_runtime_state.outputs["answer"] = value
|
||||
else:
|
||||
self.graph_runtime_state.outputs[key] = value
|
||||
self._graph_runtime_state.outputs[key] = value
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class OutputRegistry:
|
|||
with self._lock:
|
||||
self._scalars.add(selector, value)
|
||||
|
||||
def get_scalar(self, selector: Sequence[str]) -> "Segment" | None:
|
||||
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
|
||||
"""
|
||||
Get a scalar value for the given selector.
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ class OutputRegistry:
|
|||
except ValueError:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent" | None:
|
||||
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class Stream:
|
|||
raise ValueError("Cannot append to a closed stream")
|
||||
self.events.append(event)
|
||||
|
||||
def pop_next(self) -> "NodeRunStreamChunkEvent" | None:
|
||||
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue