From 82193580decd6be9f35c249f655064fb93c387f5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 30 Aug 2025 16:35:57 +0800 Subject: [PATCH] chore: improve typing Signed-off-by: -LAN- --- .../graph_engine/command_channels/redis_channel.py | 4 ++-- .../command_processing/command_handlers.py | 3 +++ .../command_processing/command_processor.py | 2 +- api/core/workflow/graph_engine/graph_engine.py | 4 +++- api/core/workflow/graph_engine/layers/debug_logging.py | 9 +++++++-- .../workflow/graph_engine/layers/execution_limits.py | 5 +++++ .../workflow/graph_engine/output_registry/registry.py | 10 ++++++---- .../graph_engine/response_coordinator/coordinator.py | 6 +++--- api/core/workflow/graph_engine/worker.py | 2 ++ .../graph_engine/worker_management/activity_tracker.py | 2 +- .../graph_engine/worker_management/worker_factory.py | 3 ++- .../graph_engine/worker_management/worker_pool.py | 3 ++- 12 files changed, 37 insertions(+), 16 deletions(-) diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 7809e43e32..ad0aa9402c 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue. """ import json -from typing import TYPE_CHECKING, final +from typing import TYPE_CHECKING, Any, final from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand @@ -87,7 +87,7 @@ class RedisChannel: pipe.expire(self._key, self._command_ttl) pipe.execute() - def _deserialize_command(self, data: dict) -> GraphEngineCommand | None: + def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: """ Deserialize a command from dictionary data. diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index 9f8d20b1b9..3c51de99f3 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -5,6 +5,8 @@ Command handler implementations. import logging from typing import final +from typing_extensions import override + from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand from .command_processor import CommandHandler @@ -16,6 +18,7 @@ logger = logging.getLogger(__name__) class AbortCommandHandler(CommandHandler): """Handles abort commands.""" + @override def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: """ Handle an abort command. diff --git a/api/core/workflow/graph_engine/command_processing/command_processor.py b/api/core/workflow/graph_engine/command_processing/command_processor.py index 2521058ef2..7051ece735 100644 --- a/api/core/workflow/graph_engine/command_processing/command_processor.py +++ b/api/core/workflow/graph_engine/command_processing/command_processor.py @@ -73,7 +73,7 @@ class CommandProcessor: if handler: try: handler.handle(command, self.graph_execution) - except Exception as e: + except Exception: logger.exception("Error handling command %s", command.__class__.__name__) else: logger.warning("No handler registered for command: %s", command.__class__.__name__) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index dd98536fba..828e9b329f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -213,7 +213,9 @@ class GraphEngine: # Capture context for workers flask_app: Flask | None = None try: - flask_app = current_app._get_current_object() # type: ignore + app = current_app._get_current_object() # type: ignore + if isinstance(app, Flask): + flask_app = app except RuntimeError: pass diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index 3052600161..42bacfa474 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -9,6 +9,8 @@ import logging from collections.abc import Mapping from typing import Any, final +from typing_extensions import override + from core.workflow.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, @@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer): if not data: return "{}" - formatted_items = [] + formatted_items: list[str] = [] for key, value in data.items(): formatted_value = self._truncate_value(value) formatted_items.append(f" {key}: {formatted_value}") return "{\n" + ",\n".join(formatted_items) + "\n}" + @override def on_graph_start(self) -> None: """Log graph execution start.""" self.logger.info("=" * 80) @@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer): # Log inputs if available if self.graph_runtime_state.variable_pool: - initial_vars = {} + initial_vars: dict[str, Any] = {} # Access the variable dictionary directly for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items(): for var_key, var in variables.items(): @@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer): if initial_vars: self.logger.info(" Initial variables: %s", self._format_dict(initial_vars)) + @override def on_event(self, event: GraphEngineEvent) -> None: """Log individual events based on their type.""" event_class = event.__class__.__name__ @@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer): # Log unknown events at debug level self.logger.debug("Event: %s", event_class) + @override def on_graph_end(self, error: Exception | None) -> None: """Log graph execution end with summary statistics.""" self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/execution_limits.py b/api/core/workflow/graph_engine/layers/execution_limits.py index efda0bacbe..6cc0c1305a 100644 --- a/api/core/workflow/graph_engine/layers/execution_limits.py +++ b/api/core/workflow/graph_engine/layers/execution_limits.py @@ -13,6 +13,8 @@ import time from enum import Enum from typing import final +from typing_extensions import override + from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType from core.workflow.graph_engine.layers import Layer from core.workflow.graph_events import ( @@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer): self._execution_ended = False self._abort_sent = False # Track if abort command has been sent + @override def on_graph_start(self) -> None: """Called when graph execution starts.""" self.start_time = time.time() @@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer): self.logger.debug("Execution limits monitoring started") + @override def on_event(self, event: GraphEngineEvent) -> None: """ Called for every event emitted by the engine. @@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer): if self._reached_time_limitation(): self._send_abort_command(LimitType.TIME_LIMIT) + @override def on_graph_end(self, error: Exception | None) -> None: """Called when graph execution ends.""" if self._execution_started and not self._execution_ended: diff --git a/api/core/workflow/graph_engine/output_registry/registry.py b/api/core/workflow/graph_engine/output_registry/registry.py index 4df7da207c..29eefa5abe 100644 --- a/api/core/workflow/graph_engine/output_registry/registry.py +++ b/api/core/workflow/graph_engine/output_registry/registry.py @@ -7,7 +7,7 @@ thread-safe storage for node outputs. from collections.abc import Sequence from threading import RLock -from typing import TYPE_CHECKING, Union, final +from typing import TYPE_CHECKING, Any, Union, final from core.variables import Segment from core.workflow.entities.variable_pool import VariablePool @@ -31,13 +31,15 @@ class OutputRegistry: """Initialize empty registry with thread-safe storage.""" self._lock = RLock() self._scalars = variable_pool - self._streams: dict[tuple, Stream] = {} + self._streams: dict[tuple[str, ...], Stream] = {} - def _selector_to_key(self, selector: Sequence[str]) -> tuple: + def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]: """Convert selector list to tuple key for internal storage.""" return tuple(selector) - def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None: + def set_scalar( + self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]] + ) -> None: """ Set a scalar value for the given selector. diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 4c3cc167fa..1fb58852d2 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -161,7 +161,7 @@ class ResponseStreamCoordinator: # Step 2: For each complete path, filter edges based on node blocking behavior filtered_paths: list[Path] = [] for path in all_complete_paths: - blocking_edges = [] + blocking_edges: list[str] = [] for edge_id in path: edge = self.graph.edges[edge_id] source_node = self.graph.nodes[edge.tail] @@ -260,7 +260,7 @@ class ResponseStreamCoordinator: if event.is_final: self.registry.close_stream(event.selector) return self.try_flush() - elif isinstance(event, NodeRunSucceededEvent): + else: # Skip cause we share the same variable pool. # # for variable_name, variable_value in event.node_run_result.outputs.items(): @@ -426,7 +426,7 @@ class ResponseStreamCoordinator: # Wait for more data break - elif isinstance(segment, TextSegment): + else: segment_events = self._process_text_segment(segment) events.extend(segment_events) self.active_session.index += 1 diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index dacf6f0435..1fb0824e63 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -15,6 +15,7 @@ from typing import final from uuid import uuid4 from flask import Flask +from typing_extensions import override from core.workflow.enums import NodeType from core.workflow.graph import Graph @@ -73,6 +74,7 @@ class Worker(threading.Thread): """Signal the worker to stop processing.""" self._stop_event.set() + @override def run(self) -> None: """ Main worker loop. diff --git a/api/core/workflow/graph_engine/worker_management/activity_tracker.py b/api/core/workflow/graph_engine/worker_management/activity_tracker.py index b2125a0158..19c4ddaeb5 100644 --- a/api/core/workflow/graph_engine/worker_management/activity_tracker.py +++ b/api/core/workflow/graph_engine/worker_management/activity_tracker.py @@ -46,7 +46,7 @@ class ActivityTracker: List of idle worker IDs """ current_time = time.time() - idle_workers = [] + idle_workers: list[int] = [] with self._lock: for worker_id, (is_active, last_change) in self._worker_activity.items(): diff --git a/api/core/workflow/graph_engine/worker_management/worker_factory.py b/api/core/workflow/graph_engine/worker_management/worker_factory.py index 673ca11f26..cbb8e0b68e 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_factory.py +++ b/api/core/workflow/graph_engine/worker_management/worker_factory.py @@ -10,6 +10,7 @@ from typing import final from flask import Flask from core.workflow.graph import Graph +from core.workflow.graph_events import GraphNodeEventBase from ..worker import Worker @@ -42,7 +43,7 @@ class WorkerFactory: def create_worker( self, ready_queue: queue.Queue[str], - event_queue: queue.Queue, + event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, on_idle_callback: Callable[[int], None] | None = None, on_active_callback: Callable[[int], None] | None = None, diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 55250809cd..bdec3e5323 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -7,6 +7,7 @@ import threading from typing import final from core.workflow.graph import Graph +from core.workflow.graph_events import GraphNodeEventBase from ..worker import Worker from .activity_tracker import ActivityTracker @@ -26,7 +27,7 @@ class WorkerPool: def __init__( self, ready_queue: queue.Queue[str], - event_queue: queue.Queue, + event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, worker_factory: WorkerFactory, dynamic_scaler: DynamicScaler,