mirror of https://github.com/langgenius/dify.git
chore: improve typing
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
1fd27cf3ad
commit
82193580de
|
|
@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, final
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ class RedisChannel:
|
|||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ Command handler implementations.
|
|||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
|
@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
"""
|
||||
Handle an abort command.
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class CommandProcessor:
|
|||
if handler:
|
||||
try:
|
||||
handler.handle(command, self.graph_execution)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||
|
|
|
|||
|
|
@ -213,7 +213,9 @@ class GraphEngine:
|
|||
# Capture context for workers
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
if isinstance(app, Flask):
|
||||
flask_app = app
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import logging
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
|
|
@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer):
|
|||
if not data:
|
||||
return "{}"
|
||||
|
||||
formatted_items = []
|
||||
formatted_items: list[str] = []
|
||||
for key, value in data.items():
|
||||
formatted_value = self._truncate_value(value)
|
||||
formatted_items.append(f" {key}: {formatted_value}")
|
||||
|
||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Log graph execution start."""
|
||||
self.logger.info("=" * 80)
|
||||
|
|
@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer):
|
|||
|
||||
# Log inputs if available
|
||||
if self.graph_runtime_state.variable_pool:
|
||||
initial_vars = {}
|
||||
initial_vars: dict[str, Any] = {}
|
||||
# Access the variable dictionary directly
|
||||
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
|
||||
for var_key, var in variables.items():
|
||||
|
|
@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer):
|
|||
if initial_vars:
|
||||
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""Log individual events based on their type."""
|
||||
event_class = event.__class__.__name__
|
||||
|
|
@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer):
|
|||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ import time
|
|||
from enum import Enum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
from core.workflow.graph_events import (
|
||||
|
|
@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer):
|
|||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
|
|
@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer):
|
|||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
|
@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer):
|
|||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ thread-safe storage for node outputs.
|
|||
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Union, final
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
from core.variables import Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
|
@ -31,13 +31,15 @@ class OutputRegistry:
|
|||
"""Initialize empty registry with thread-safe storage."""
|
||||
self._lock = RLock()
|
||||
self._scalars = variable_pool
|
||||
self._streams: dict[tuple, Stream] = {}
|
||||
self._streams: dict[tuple[str, ...], Stream] = {}
|
||||
|
||||
def _selector_to_key(self, selector: Sequence[str]) -> tuple:
|
||||
def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
|
||||
"""Convert selector list to tuple key for internal storage."""
|
||||
return tuple(selector)
|
||||
|
||||
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
|
||||
def set_scalar(
|
||||
self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
|
||||
) -> None:
|
||||
"""
|
||||
Set a scalar value for the given selector.
|
||||
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ class ResponseStreamCoordinator:
|
|||
# Step 2: For each complete path, filter edges based on node blocking behavior
|
||||
filtered_paths: list[Path] = []
|
||||
for path in all_complete_paths:
|
||||
blocking_edges = []
|
||||
blocking_edges: list[str] = []
|
||||
for edge_id in path:
|
||||
edge = self.graph.edges[edge_id]
|
||||
source_node = self.graph.nodes[edge.tail]
|
||||
|
|
@ -260,7 +260,7 @@ class ResponseStreamCoordinator:
|
|||
if event.is_final:
|
||||
self.registry.close_stream(event.selector)
|
||||
return self.try_flush()
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
else:
|
||||
# Skip cause we share the same variable pool.
|
||||
#
|
||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
|
|
@ -426,7 +426,7 @@ class ResponseStreamCoordinator:
|
|||
# Wait for more data
|
||||
break
|
||||
|
||||
elif isinstance(segment, TextSegment):
|
||||
else:
|
||||
segment_events = self._process_text_segment(segment)
|
||||
events.extend(segment_events)
|
||||
self.active_session.index += 1
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from typing import final
|
|||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@ -73,6 +74,7 @@ class Worker(threading.Thread):
|
|||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@override
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Main worker loop.
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class ActivityTracker:
|
|||
List of idle worker IDs
|
||||
"""
|
||||
current_time = time.time()
|
||||
idle_workers = []
|
||||
idle_workers: list[int] = []
|
||||
|
||||
with self._lock:
|
||||
for worker_id, (is_active, last_change) in self._worker_activity.items():
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import final
|
|||
from flask import Flask
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..worker import Worker
|
||||
|
||||
|
|
@ -42,7 +43,7 @@ class WorkerFactory:
|
|||
def create_worker(
|
||||
self,
|
||||
ready_queue: queue.Queue[str],
|
||||
event_queue: queue.Queue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import threading
|
|||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..worker import Worker
|
||||
from .activity_tracker import ActivityTracker
|
||||
|
|
@ -26,7 +27,7 @@ class WorkerPool:
|
|||
def __init__(
|
||||
self,
|
||||
ready_queue: queue.Queue[str],
|
||||
event_queue: queue.Queue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
worker_factory: WorkerFactory,
|
||||
dynamic_scaler: DynamicScaler,
|
||||
|
|
|
|||
Loading…
Reference in New Issue