chore: improve typing

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-08-30 16:35:57 +08:00
parent 1fd27cf3ad
commit 82193580de
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
12 changed files with 37 additions and 16 deletions

View File

@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
""" """
import json import json
from typing import TYPE_CHECKING, final from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
@ -87,7 +87,7 @@ class RedisChannel:
pipe.expire(self._key, self._command_ttl) pipe.expire(self._key, self._command_ttl)
pipe.execute() pipe.execute()
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None: def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
""" """
Deserialize a command from dictionary data. Deserialize a command from dictionary data.

View File

@ -5,6 +5,8 @@ Command handler implementations.
import logging import logging
from typing import final from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler from .command_processor import CommandHandler
@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
class AbortCommandHandler(CommandHandler): class AbortCommandHandler(CommandHandler):
"""Handles abort commands.""" """Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
""" """
Handle an abort command. Handle an abort command.

View File

@ -73,7 +73,7 @@ class CommandProcessor:
if handler: if handler:
try: try:
handler.handle(command, self.graph_execution) handler.handle(command, self.graph_execution)
except Exception as e: except Exception:
logger.exception("Error handling command %s", command.__class__.__name__) logger.exception("Error handling command %s", command.__class__.__name__)
else: else:
logger.warning("No handler registered for command: %s", command.__class__.__name__) logger.warning("No handler registered for command: %s", command.__class__.__name__)

View File

@ -213,7 +213,9 @@ class GraphEngine:
# Capture context for workers # Capture context for workers
flask_app: Flask | None = None flask_app: Flask | None = None
try: try:
flask_app = current_app._get_current_object() # type: ignore app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError: except RuntimeError:
pass pass

View File

@ -9,6 +9,8 @@ import logging
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, final from typing import Any, final
from typing_extensions import override
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphEngineEvent, GraphEngineEvent,
GraphRunAbortedEvent, GraphRunAbortedEvent,
@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer):
if not data: if not data:
return "{}" return "{}"
formatted_items = [] formatted_items: list[str] = []
for key, value in data.items(): for key, value in data.items():
formatted_value = self._truncate_value(value) formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}") formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}" return "{\n" + ",\n".join(formatted_items) + "\n}"
@override
def on_graph_start(self) -> None: def on_graph_start(self) -> None:
"""Log graph execution start.""" """Log graph execution start."""
self.logger.info("=" * 80) self.logger.info("=" * 80)
@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer):
# Log inputs if available # Log inputs if available
if self.graph_runtime_state.variable_pool: if self.graph_runtime_state.variable_pool:
initial_vars = {} initial_vars: dict[str, Any] = {}
# Access the variable dictionary directly # Access the variable dictionary directly
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items(): for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
for var_key, var in variables.items(): for var_key, var in variables.items():
@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer):
if initial_vars: if initial_vars:
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars)) self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
@override
def on_event(self, event: GraphEngineEvent) -> None: def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type.""" """Log individual events based on their type."""
event_class = event.__class__.__name__ event_class = event.__class__.__name__
@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer):
# Log unknown events at debug level # Log unknown events at debug level
self.logger.debug("Event: %s", event_class) self.logger.debug("Event: %s", event_class)
@override
def on_graph_end(self, error: Exception | None) -> None: def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics.""" """Log graph execution end with summary statistics."""
self.logger.info("=" * 80) self.logger.info("=" * 80)

View File

@ -13,6 +13,8 @@ import time
from enum import Enum from enum import Enum
from typing import final from typing import final
from typing_extensions import override
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer from core.workflow.graph_engine.layers import Layer
from core.workflow.graph_events import ( from core.workflow.graph_events import (
@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer):
self._execution_ended = False self._execution_ended = False
self._abort_sent = False # Track if abort command has been sent self._abort_sent = False # Track if abort command has been sent
@override
def on_graph_start(self) -> None: def on_graph_start(self) -> None:
"""Called when graph execution starts.""" """Called when graph execution starts."""
self.start_time = time.time() self.start_time = time.time()
@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer):
self.logger.debug("Execution limits monitoring started") self.logger.debug("Execution limits monitoring started")
@override
def on_event(self, event: GraphEngineEvent) -> None: def on_event(self, event: GraphEngineEvent) -> None:
""" """
Called for every event emitted by the engine. Called for every event emitted by the engine.
@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer):
if self._reached_time_limitation(): if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT) self._send_abort_command(LimitType.TIME_LIMIT)
@override
def on_graph_end(self, error: Exception | None) -> None: def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends.""" """Called when graph execution ends."""
if self._execution_started and not self._execution_ended: if self._execution_started and not self._execution_ended:

View File

@ -7,7 +7,7 @@ thread-safe storage for node outputs.
from collections.abc import Sequence from collections.abc import Sequence
from threading import RLock from threading import RLock
from typing import TYPE_CHECKING, Union, final from typing import TYPE_CHECKING, Any, Union, final
from core.variables import Segment from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -31,13 +31,15 @@ class OutputRegistry:
"""Initialize empty registry with thread-safe storage.""" """Initialize empty registry with thread-safe storage."""
self._lock = RLock() self._lock = RLock()
self._scalars = variable_pool self._scalars = variable_pool
self._streams: dict[tuple, Stream] = {} self._streams: dict[tuple[str, ...], Stream] = {}
def _selector_to_key(self, selector: Sequence[str]) -> tuple: def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
"""Convert selector list to tuple key for internal storage.""" """Convert selector list to tuple key for internal storage."""
return tuple(selector) return tuple(selector)
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None: def set_scalar(
self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
) -> None:
""" """
Set a scalar value for the given selector. Set a scalar value for the given selector.

View File

@ -161,7 +161,7 @@ class ResponseStreamCoordinator:
# Step 2: For each complete path, filter edges based on node blocking behavior # Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = [] filtered_paths: list[Path] = []
for path in all_complete_paths: for path in all_complete_paths:
blocking_edges = [] blocking_edges: list[str] = []
for edge_id in path: for edge_id in path:
edge = self.graph.edges[edge_id] edge = self.graph.edges[edge_id]
source_node = self.graph.nodes[edge.tail] source_node = self.graph.nodes[edge.tail]
@ -260,7 +260,7 @@ class ResponseStreamCoordinator:
if event.is_final: if event.is_final:
self.registry.close_stream(event.selector) self.registry.close_stream(event.selector)
return self.try_flush() return self.try_flush()
elif isinstance(event, NodeRunSucceededEvent): else:
# Skip cause we share the same variable pool. # Skip cause we share the same variable pool.
# #
# for variable_name, variable_value in event.node_run_result.outputs.items(): # for variable_name, variable_value in event.node_run_result.outputs.items():
@ -426,7 +426,7 @@ class ResponseStreamCoordinator:
# Wait for more data # Wait for more data
break break
elif isinstance(segment, TextSegment): else:
segment_events = self._process_text_segment(segment) segment_events = self._process_text_segment(segment)
events.extend(segment_events) events.extend(segment_events)
self.active_session.index += 1 self.active_session.index += 1

View File

@ -15,6 +15,7 @@ from typing import final
from uuid import uuid4 from uuid import uuid4
from flask import Flask from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
from core.workflow.graph import Graph from core.workflow.graph import Graph
@ -73,6 +74,7 @@ class Worker(threading.Thread):
"""Signal the worker to stop processing.""" """Signal the worker to stop processing."""
self._stop_event.set() self._stop_event.set()
@override
def run(self) -> None: def run(self) -> None:
""" """
Main worker loop. Main worker loop.

View File

@ -46,7 +46,7 @@ class ActivityTracker:
List of idle worker IDs List of idle worker IDs
""" """
current_time = time.time() current_time = time.time()
idle_workers = [] idle_workers: list[int] = []
with self._lock: with self._lock:
for worker_id, (is_active, last_change) in self._worker_activity.items(): for worker_id, (is_active, last_change) in self._worker_activity.items():

View File

@ -10,6 +10,7 @@ from typing import final
from flask import Flask from flask import Flask
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker from ..worker import Worker
@ -42,7 +43,7 @@ class WorkerFactory:
def create_worker( def create_worker(
self, self,
ready_queue: queue.Queue[str], ready_queue: queue.Queue[str],
event_queue: queue.Queue, event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph, graph: Graph,
on_idle_callback: Callable[[int], None] | None = None, on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None, on_active_callback: Callable[[int], None] | None = None,

View File

@ -7,6 +7,7 @@ import threading
from typing import final from typing import final
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker from ..worker import Worker
from .activity_tracker import ActivityTracker from .activity_tracker import ActivityTracker
@ -26,7 +27,7 @@ class WorkerPool:
def __init__( def __init__(
self, self,
ready_queue: queue.Queue[str], ready_queue: queue.Queue[str],
event_queue: queue.Queue, event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph, graph: Graph,
worker_factory: WorkerFactory, worker_factory: WorkerFactory,
dynamic_scaler: DynamicScaler, dynamic_scaler: DynamicScaler,