chore: improve typing

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-08-30 16:35:57 +08:00
parent 1fd27cf3ad
commit 82193580de
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
12 changed files with 37 additions and 16 deletions

View File

@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, final
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
@ -87,7 +87,7 @@ class RedisChannel:
pipe.expire(self._key, self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.

View File

@ -5,6 +5,8 @@ Command handler implementations.
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler
@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
"""
Handle an abort command.

View File

@ -73,7 +73,7 @@ class CommandProcessor:
if handler:
try:
handler.handle(command, self.graph_execution)
except Exception as e:
except Exception:
logger.exception("Error handling command %s", command.__class__.__name__)
else:
logger.warning("No handler registered for command: %s", command.__class__.__name__)

View File

@ -213,7 +213,9 @@ class GraphEngine:
# Capture context for workers
flask_app: Flask | None = None
try:
flask_app = current_app._get_current_object() # type: ignore
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass

View File

@ -9,6 +9,8 @@ import logging
from collections.abc import Mapping
from typing import Any, final
from typing_extensions import override
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer):
if not data:
return "{}"
formatted_items = []
formatted_items: list[str] = []
for key, value in data.items():
formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}"
@override
def on_graph_start(self) -> None:
"""Log graph execution start."""
self.logger.info("=" * 80)
@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer):
# Log inputs if available
if self.graph_runtime_state.variable_pool:
initial_vars = {}
initial_vars: dict[str, Any] = {}
# Access the variable dictionary directly
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
for var_key, var in variables.items():
@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer):
if initial_vars:
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""
event_class = event.__class__.__name__
@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer):
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)

View File

@ -13,6 +13,8 @@ import time
from enum import Enum
from typing import final
from typing_extensions import override
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer
from core.workflow.graph_events import (
@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer):
self._execution_ended = False
self._abort_sent = False # Track if abort command has been sent
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self.start_time = time.time()
@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer):
self.logger.debug("Execution limits monitoring started")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer):
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:

View File

@ -7,7 +7,7 @@ thread-safe storage for node outputs.
from collections.abc import Sequence
from threading import RLock
from typing import TYPE_CHECKING, Union, final
from typing import TYPE_CHECKING, Any, Union, final
from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool
@ -31,13 +31,15 @@ class OutputRegistry:
"""Initialize empty registry with thread-safe storage."""
self._lock = RLock()
self._scalars = variable_pool
self._streams: dict[tuple, Stream] = {}
self._streams: dict[tuple[str, ...], Stream] = {}
def _selector_to_key(self, selector: Sequence[str]) -> tuple:
def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
"""Convert selector list to tuple key for internal storage."""
return tuple(selector)
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
def set_scalar(
self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
) -> None:
"""
Set a scalar value for the given selector.

View File

@ -161,7 +161,7 @@ class ResponseStreamCoordinator:
# Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = []
for path in all_complete_paths:
blocking_edges = []
blocking_edges: list[str] = []
for edge_id in path:
edge = self.graph.edges[edge_id]
source_node = self.graph.nodes[edge.tail]
@ -260,7 +260,7 @@ class ResponseStreamCoordinator:
if event.is_final:
self.registry.close_stream(event.selector)
return self.try_flush()
elif isinstance(event, NodeRunSucceededEvent):
else:
# Skip because we share the same variable pool.
#
# for variable_name, variable_value in event.node_run_result.outputs.items():
@ -426,7 +426,7 @@ class ResponseStreamCoordinator:
# Wait for more data
break
elif isinstance(segment, TextSegment):
else:
segment_events = self._process_text_segment(segment)
events.extend(segment_events)
self.active_session.index += 1

View File

@ -15,6 +15,7 @@ from typing import final
from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
@ -73,6 +74,7 @@ class Worker(threading.Thread):
"""Signal the worker to stop processing."""
self._stop_event.set()
@override
def run(self) -> None:
"""
Main worker loop.

View File

@ -46,7 +46,7 @@ class ActivityTracker:
List of idle worker IDs
"""
current_time = time.time()
idle_workers = []
idle_workers: list[int] = []
with self._lock:
for worker_id, (is_active, last_change) in self._worker_activity.items():

View File

@ -10,6 +10,7 @@ from typing import final
from flask import Flask
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker
@ -42,7 +43,7 @@ class WorkerFactory:
def create_worker(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,

View File

@ -7,6 +7,7 @@ import threading
from typing import final
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker
from .activity_tracker import ActivityTracker
@ -26,7 +27,7 @@ class WorkerPool:
def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_factory: WorkerFactory,
dynamic_scaler: DynamicScaler,