mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 12:37:20 +08:00
chore: improve typing
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
1fd27cf3ad
commit
82193580de
@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, final
|
from typing import TYPE_CHECKING, Any, final
|
||||||
|
|
||||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ class RedisChannel:
|
|||||||
pipe.expire(self._key, self._command_ttl)
|
pipe.expire(self._key, self._command_ttl)
|
||||||
pipe.execute()
|
pipe.execute()
|
||||||
|
|
||||||
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
|
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||||
"""
|
"""
|
||||||
Deserialize a command from dictionary data.
|
Deserialize a command from dictionary data.
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,8 @@ Command handler implementations.
|
|||||||
import logging
|
import logging
|
||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..domain.graph_execution import GraphExecution
|
from ..domain.graph_execution import GraphExecution
|
||||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||||
from .command_processor import CommandHandler
|
from .command_processor import CommandHandler
|
||||||
@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class AbortCommandHandler(CommandHandler):
|
class AbortCommandHandler(CommandHandler):
|
||||||
"""Handles abort commands."""
|
"""Handles abort commands."""
|
||||||
|
|
||||||
|
@override
|
||||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||||
"""
|
"""
|
||||||
Handle an abort command.
|
Handle an abort command.
|
||||||
|
|||||||
@ -73,7 +73,7 @@ class CommandProcessor:
|
|||||||
if handler:
|
if handler:
|
||||||
try:
|
try:
|
||||||
handler.handle(command, self.graph_execution)
|
handler.handle(command, self.graph_execution)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||||
else:
|
else:
|
||||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||||
|
|||||||
@ -213,7 +213,9 @@ class GraphEngine:
|
|||||||
# Capture context for workers
|
# Capture context for workers
|
||||||
flask_app: Flask | None = None
|
flask_app: Flask | None = None
|
||||||
try:
|
try:
|
||||||
flask_app = current_app._get_current_object() # type: ignore
|
app = current_app._get_current_object() # type: ignore
|
||||||
|
if isinstance(app, Flask):
|
||||||
|
flask_app = app
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import logging
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, final
|
from typing import Any, final
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunAbortedEvent,
|
GraphRunAbortedEvent,
|
||||||
@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer):
|
|||||||
if not data:
|
if not data:
|
||||||
return "{}"
|
return "{}"
|
||||||
|
|
||||||
formatted_items = []
|
formatted_items: list[str] = []
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
formatted_value = self._truncate_value(value)
|
formatted_value = self._truncate_value(value)
|
||||||
formatted_items.append(f" {key}: {formatted_value}")
|
formatted_items.append(f" {key}: {formatted_value}")
|
||||||
|
|
||||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||||
|
|
||||||
|
@override
|
||||||
def on_graph_start(self) -> None:
|
def on_graph_start(self) -> None:
|
||||||
"""Log graph execution start."""
|
"""Log graph execution start."""
|
||||||
self.logger.info("=" * 80)
|
self.logger.info("=" * 80)
|
||||||
@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer):
|
|||||||
|
|
||||||
# Log inputs if available
|
# Log inputs if available
|
||||||
if self.graph_runtime_state.variable_pool:
|
if self.graph_runtime_state.variable_pool:
|
||||||
initial_vars = {}
|
initial_vars: dict[str, Any] = {}
|
||||||
# Access the variable dictionary directly
|
# Access the variable dictionary directly
|
||||||
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
|
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
|
||||||
for var_key, var in variables.items():
|
for var_key, var in variables.items():
|
||||||
@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer):
|
|||||||
if initial_vars:
|
if initial_vars:
|
||||||
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
|
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
|
||||||
|
|
||||||
|
@override
|
||||||
def on_event(self, event: GraphEngineEvent) -> None:
|
def on_event(self, event: GraphEngineEvent) -> None:
|
||||||
"""Log individual events based on their type."""
|
"""Log individual events based on their type."""
|
||||||
event_class = event.__class__.__name__
|
event_class = event.__class__.__name__
|
||||||
@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer):
|
|||||||
# Log unknown events at debug level
|
# Log unknown events at debug level
|
||||||
self.logger.debug("Event: %s", event_class)
|
self.logger.debug("Event: %s", event_class)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_graph_end(self, error: Exception | None) -> None:
|
def on_graph_end(self, error: Exception | None) -> None:
|
||||||
"""Log graph execution end with summary statistics."""
|
"""Log graph execution end with summary statistics."""
|
||||||
self.logger.info("=" * 80)
|
self.logger.info("=" * 80)
|
||||||
|
|||||||
@ -13,6 +13,8 @@ import time
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||||
from core.workflow.graph_engine.layers import Layer
|
from core.workflow.graph_engine.layers import Layer
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer):
|
|||||||
self._execution_ended = False
|
self._execution_ended = False
|
||||||
self._abort_sent = False # Track if abort command has been sent
|
self._abort_sent = False # Track if abort command has been sent
|
||||||
|
|
||||||
|
@override
|
||||||
def on_graph_start(self) -> None:
|
def on_graph_start(self) -> None:
|
||||||
"""Called when graph execution starts."""
|
"""Called when graph execution starts."""
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer):
|
|||||||
|
|
||||||
self.logger.debug("Execution limits monitoring started")
|
self.logger.debug("Execution limits monitoring started")
|
||||||
|
|
||||||
|
@override
|
||||||
def on_event(self, event: GraphEngineEvent) -> None:
|
def on_event(self, event: GraphEngineEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Called for every event emitted by the engine.
|
Called for every event emitted by the engine.
|
||||||
@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer):
|
|||||||
if self._reached_time_limitation():
|
if self._reached_time_limitation():
|
||||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||||
|
|
||||||
|
@override
|
||||||
def on_graph_end(self, error: Exception | None) -> None:
|
def on_graph_end(self, error: Exception | None) -> None:
|
||||||
"""Called when graph execution ends."""
|
"""Called when graph execution ends."""
|
||||||
if self._execution_started and not self._execution_ended:
|
if self._execution_started and not self._execution_ended:
|
||||||
|
|||||||
@ -7,7 +7,7 @@ thread-safe storage for node outputs.
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import TYPE_CHECKING, Union, final
|
from typing import TYPE_CHECKING, Any, Union, final
|
||||||
|
|
||||||
from core.variables import Segment
|
from core.variables import Segment
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
@ -31,13 +31,15 @@ class OutputRegistry:
|
|||||||
"""Initialize empty registry with thread-safe storage."""
|
"""Initialize empty registry with thread-safe storage."""
|
||||||
self._lock = RLock()
|
self._lock = RLock()
|
||||||
self._scalars = variable_pool
|
self._scalars = variable_pool
|
||||||
self._streams: dict[tuple, Stream] = {}
|
self._streams: dict[tuple[str, ...], Stream] = {}
|
||||||
|
|
||||||
def _selector_to_key(self, selector: Sequence[str]) -> tuple:
|
def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
|
||||||
"""Convert selector list to tuple key for internal storage."""
|
"""Convert selector list to tuple key for internal storage."""
|
||||||
return tuple(selector)
|
return tuple(selector)
|
||||||
|
|
||||||
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
|
def set_scalar(
|
||||||
|
self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Set a scalar value for the given selector.
|
Set a scalar value for the given selector.
|
||||||
|
|
||||||
|
|||||||
@ -161,7 +161,7 @@ class ResponseStreamCoordinator:
|
|||||||
# Step 2: For each complete path, filter edges based on node blocking behavior
|
# Step 2: For each complete path, filter edges based on node blocking behavior
|
||||||
filtered_paths: list[Path] = []
|
filtered_paths: list[Path] = []
|
||||||
for path in all_complete_paths:
|
for path in all_complete_paths:
|
||||||
blocking_edges = []
|
blocking_edges: list[str] = []
|
||||||
for edge_id in path:
|
for edge_id in path:
|
||||||
edge = self.graph.edges[edge_id]
|
edge = self.graph.edges[edge_id]
|
||||||
source_node = self.graph.nodes[edge.tail]
|
source_node = self.graph.nodes[edge.tail]
|
||||||
@ -260,7 +260,7 @@ class ResponseStreamCoordinator:
|
|||||||
if event.is_final:
|
if event.is_final:
|
||||||
self.registry.close_stream(event.selector)
|
self.registry.close_stream(event.selector)
|
||||||
return self.try_flush()
|
return self.try_flush()
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
else:
|
||||||
# Skip cause we share the same variable pool.
|
# Skip cause we share the same variable pool.
|
||||||
#
|
#
|
||||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||||
@ -426,7 +426,7 @@ class ResponseStreamCoordinator:
|
|||||||
# Wait for more data
|
# Wait for more data
|
||||||
break
|
break
|
||||||
|
|
||||||
elif isinstance(segment, TextSegment):
|
else:
|
||||||
segment_events = self._process_text_segment(segment)
|
segment_events = self._process_text_segment(segment)
|
||||||
events.extend(segment_events)
|
events.extend(segment_events)
|
||||||
self.active_session.index += 1
|
self.active_session.index += 1
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from typing import final
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
@ -73,6 +74,7 @@ class Worker(threading.Thread):
|
|||||||
"""Signal the worker to stop processing."""
|
"""Signal the worker to stop processing."""
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
|
|
||||||
|
@override
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
Main worker loop.
|
Main worker loop.
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class ActivityTracker:
|
|||||||
List of idle worker IDs
|
List of idle worker IDs
|
||||||
"""
|
"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
idle_workers = []
|
idle_workers: list[int] = []
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for worker_id, (is_active, last_change) in self._worker_activity.items():
|
for worker_id, (is_active, last_change) in self._worker_activity.items():
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from typing import final
|
|||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
from core.workflow.graph_events import GraphNodeEventBase
|
||||||
|
|
||||||
from ..worker import Worker
|
from ..worker import Worker
|
||||||
|
|
||||||
@ -42,7 +43,7 @@ class WorkerFactory:
|
|||||||
def create_worker(
|
def create_worker(
|
||||||
self,
|
self,
|
||||||
ready_queue: queue.Queue[str],
|
ready_queue: queue.Queue[str],
|
||||||
event_queue: queue.Queue,
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
on_idle_callback: Callable[[int], None] | None = None,
|
on_idle_callback: Callable[[int], None] | None = None,
|
||||||
on_active_callback: Callable[[int], None] | None = None,
|
on_active_callback: Callable[[int], None] | None = None,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import threading
|
|||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
from core.workflow.graph_events import GraphNodeEventBase
|
||||||
|
|
||||||
from ..worker import Worker
|
from ..worker import Worker
|
||||||
from .activity_tracker import ActivityTracker
|
from .activity_tracker import ActivityTracker
|
||||||
@ -26,7 +27,7 @@ class WorkerPool:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ready_queue: queue.Queue[str],
|
ready_queue: queue.Queue[str],
|
||||||
event_queue: queue.Queue,
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
worker_factory: WorkerFactory,
|
worker_factory: WorkerFactory,
|
||||||
dynamic_scaler: DynamicScaler,
|
dynamic_scaler: DynamicScaler,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user