diff --git a/api/.importlinter b/api/.importlinter
index 9aa1073c38..9205b7c94d 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -34,7 +34,7 @@ ignore_imports =
 [importlinter:contract:rsc]
 name = RSC
 type = layers
-layers = 
+layers =
     graph_engine
     response_coordinator
     output_registry
@@ -44,7 +44,7 @@ containers =
 [importlinter:contract:worker]
 name = Worker
 type = layers
-layers = 
+layers =
     graph_engine
     worker
 containers =
@@ -77,18 +77,8 @@ forbidden_modules =
     core.workflow.graph_engine.layers
     core.workflow.graph_engine.protocols
 
-[importlinter:contract:state-management-layers]
-name = State Management Layers
-type = layers
-layers =
-    execution_tracker
-    node_state_manager
-    edge_state_manager
-containers =
-    core.workflow.graph_engine.state_management
-
 [importlinter:contract:worker-management-layers]
-name = Worker Management Layers 
+name = Worker Management Layers
 type = layers
 layers =
     worker_pool
@@ -119,4 +109,4 @@ name = Command Channels Independence
 type = independence
 modules =
     core.workflow.graph_engine.command_channels.in_memory_channel
-    core.workflow.graph_engine.command_channels.redis_channel
\ No newline at end of file
+    core.workflow.graph_engine.command_channels.redis_channel
diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py
index 842bd2635f..bdd1c4d245 100644
--- a/api/core/workflow/graph_engine/event_management/event_handlers.py
+++ b/api/core/workflow/graph_engine/event_management/event_handlers.py
@@ -32,7 +32,7 @@ from ..response_coordinator import ResponseStreamCoordinator
 if TYPE_CHECKING:
     from ..error_handling import ErrorHandler
     from ..graph_traversal import BranchHandler, EdgeProcessor
-    from ..state_management import ExecutionTracker, NodeStateManager
+    from ..state_management import UnifiedStateManager
     from .event_collector import EventCollector
 
 logger = logging.getLogger(__name__)
@@ -56,8 +56,8 @@ class EventHandlerRegistry:
         event_collector: "EventCollector",
         branch_handler: "BranchHandler",
         edge_processor: "EdgeProcessor",
-        node_state_manager: "NodeStateManager",
-        execution_tracker: "ExecutionTracker",
+        node_state_manager: "UnifiedStateManager",
+        execution_tracker: "UnifiedStateManager",
         error_handler: "ErrorHandler",
     ) -> None:
         """
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 828e9b329f..7398b846d8 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -39,7 +39,7 @@ from .orchestration import Dispatcher, ExecutionCoordinator
 from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
-from .state_management import EdgeStateManager, ExecutionTracker, NodeStateManager
+from .state_management import UnifiedStateManager
 from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, WorkerPool
 
 logger = logging.getLogger(__name__)
@@ -119,10 +119,8 @@ class GraphEngine:
 
     def _initialize_subsystems(self) -> None:
         """Initialize all subsystems with proper dependency injection."""
-        # State management
-        self.node_state_manager = NodeStateManager(self.graph, self.ready_queue)
-        self.edge_state_manager = EdgeStateManager(self.graph)
-        self.execution_tracker = ExecutionTracker()
+        # Unified state management - single instance handles all state operations
+        self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
 
         # Response coordination
         self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
@@ -139,20 +137,20 @@ class GraphEngine:
         self.node_readiness_checker = NodeReadinessChecker(self.graph)
         self.edge_processor = EdgeProcessor(
             graph=self.graph,
-            edge_state_manager=self.edge_state_manager,
-            node_state_manager=self.node_state_manager,
+            edge_state_manager=self.state_manager,
+            node_state_manager=self.state_manager,
             response_coordinator=self.response_coordinator,
         )
         self.skip_propagator = SkipPropagator(
             graph=self.graph,
-            edge_state_manager=self.edge_state_manager,
-            node_state_manager=self.node_state_manager,
+            edge_state_manager=self.state_manager,
+            node_state_manager=self.state_manager,
         )
         self.branch_handler = BranchHandler(
             graph=self.graph,
             edge_processor=self.edge_processor,
             skip_propagator=self.skip_propagator,
-            edge_state_manager=self.edge_state_manager,
+            edge_state_manager=self.state_manager,
         )
 
         # Event handler registry with all dependencies
@@ -164,8 +162,8 @@ class GraphEngine:
             event_collector=self.event_collector,
             branch_handler=self.branch_handler,
             edge_processor=self.edge_processor,
-            node_state_manager=self.node_state_manager,
-            execution_tracker=self.execution_tracker,
+            node_state_manager=self.state_manager,
+            execution_tracker=self.state_manager,
             error_handler=self.error_handler,
         )
 
@@ -182,8 +180,8 @@ class GraphEngine:
         # Orchestration
         self.execution_coordinator = ExecutionCoordinator(
             graph_execution=self.graph_execution,
-            node_state_manager=self.node_state_manager,
-            execution_tracker=self.execution_tracker,
+            node_state_manager=self.state_manager,
+            execution_tracker=self.state_manager,
             event_handler=self.event_handler_registry,
             event_collector=self.event_collector,
             command_processor=self.command_processor,
@@ -335,8 +333,8 @@ class GraphEngine:
 
         # Enqueue root node
         root_node = self.graph.root_node
-        self.node_state_manager.enqueue_node(root_node.id)
-        self.execution_tracker.add(root_node.id)
+        self.state_manager.enqueue_node(root_node.id)
+        self.state_manager.add(root_node.id)
 
         # Start dispatcher
         self.dispatcher.start()
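With _initialize_subsystems building a single UnifiedStateManager, every collaborator receives the same instance under its old parameter name, so node, edge, and execution state share one lock. A minimal wiring sketch of the pattern (the graph and response_coordinator objects are built elsewhere by the engine and elided here):

    import queue

    from core.workflow.graph_engine.state_management import UnifiedStateManager

    ready_queue: queue.Queue[str] = queue.Queue()
    state = UnifiedStateManager(graph, ready_queue)

    # One object now plays the roles that used to be three separate managers:
    edge_processor = EdgeProcessor(
        graph=graph,
        edge_state_manager=state,
        node_state_manager=state,
        response_coordinator=response_coordinator,
    )

    # Engine startup, as in the GraphEngine hunk above:
    state.enqueue_node(graph.root_node.id)  # mark TAKEN and push to ready_queue
    state.add(graph.root_node.id)           # ExecutionTracker-compatible alias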
diff --git a/api/core/workflow/graph_engine/graph_traversal/branch_handler.py b/api/core/workflow/graph_engine/graph_traversal/branch_handler.py
index b371f3bc73..8e08a03e3c 100644
--- a/api/core/workflow/graph_engine/graph_traversal/branch_handler.py
+++ b/api/core/workflow/graph_engine/graph_traversal/branch_handler.py
@@ -8,7 +8,7 @@ from typing import final
 from core.workflow.graph import Graph
 from core.workflow.graph_events.node import NodeRunStreamChunkEvent
 
-from ..state_management import EdgeStateManager
+from ..state_management import UnifiedStateManager
 from .edge_processor import EdgeProcessor
 from .skip_propagator import SkipPropagator
 
@@ -27,7 +27,7 @@ class BranchHandler:
         graph: Graph,
         edge_processor: EdgeProcessor,
         skip_propagator: SkipPropagator,
-        edge_state_manager: EdgeStateManager,
+        edge_state_manager: UnifiedStateManager,
     ) -> None:
         """
         Initialize the branch handler.
diff --git a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py
index ac2c658b4b..6efb56f046 100644
--- a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py
+++ b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py
@@ -10,7 +10,7 @@ from core.workflow.graph import Edge, Graph
 from core.workflow.graph_events import NodeRunStreamChunkEvent
 
 from ..response_coordinator import ResponseStreamCoordinator
-from ..state_management import EdgeStateManager, NodeStateManager
+from ..state_management import UnifiedStateManager
 
 
 @final
@@ -25,8 +25,8 @@ class EdgeProcessor:
     def __init__(
         self,
         graph: Graph,
-        edge_state_manager: EdgeStateManager,
-        node_state_manager: NodeStateManager,
+        edge_state_manager: UnifiedStateManager,
+        node_state_manager: UnifiedStateManager,
         response_coordinator: ResponseStreamCoordinator,
     ) -> None:
         """
diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
index 5ac445d405..01426809eb 100644
--- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
+++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
@@ -7,7 +7,7 @@ from typing import final
 
 from core.workflow.graph import Edge, Graph
 
-from ..state_management import EdgeStateManager, NodeStateManager
+from ..state_management import UnifiedStateManager
 
 
 @final
@@ -22,8 +22,8 @@ class SkipPropagator:
     def __init__(
         self,
         graph: Graph,
-        edge_state_manager: EdgeStateManager,
-        node_state_manager: NodeStateManager,
+        edge_state_manager: UnifiedStateManager,
+        node_state_manager: UnifiedStateManager,
     ) -> None:
         """
         Initialize the skip propagator.
diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py
index 5f95b5b29e..3d9783703e 100644
--- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py
+++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, final
 from ..command_processing import CommandProcessor
 from ..domain import GraphExecution
 from ..event_management import EventCollector
-from ..state_management import ExecutionTracker, NodeStateManager
+from ..state_management import UnifiedStateManager
 from ..worker_management import WorkerPool
 
 if TYPE_CHECKING:
@@ -26,8 +26,8 @@ class ExecutionCoordinator:
     def __init__(
         self,
         graph_execution: GraphExecution,
-        node_state_manager: NodeStateManager,
-        execution_tracker: ExecutionTracker,
+        node_state_manager: UnifiedStateManager,
+        execution_tracker: UnifiedStateManager,
         event_handler: "EventHandlerRegistry",
         event_collector: EventCollector,
         command_processor: CommandProcessor,
diff --git a/api/core/workflow/graph_engine/state_management/__init__.py b/api/core/workflow/graph_engine/state_management/__init__.py
index 6680696ed2..9a632a3b9f 100644
--- a/api/core/workflow/graph_engine/state_management/__init__.py
+++ b/api/core/workflow/graph_engine/state_management/__init__.py
@@ -5,12 +5,8 @@ This package manages node states, edge states,
 and execution tracking during workflow graph execution.
 """
 
-from .edge_state_manager import EdgeStateManager
-from .execution_tracker import ExecutionTracker
-from .node_state_manager import NodeStateManager
+from .unified_state_manager import UnifiedStateManager
 
 __all__ = [
-    "EdgeStateManager",
-    "ExecutionTracker",
-    "NodeStateManager",
+    "UnifiedStateManager",
 ]
diff --git a/api/core/workflow/graph_engine/state_management/edge_state_manager.py b/api/core/workflow/graph_engine/state_management/edge_state_manager.py
deleted file mode 100644
index 747062284a..0000000000
--- a/api/core/workflow/graph_engine/state_management/edge_state_manager.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""
-Manager for edge states during graph execution.
-"""
-
-import threading
-from collections.abc import Sequence
-from typing import TypedDict, final
-
-from core.workflow.enums import NodeState
-from core.workflow.graph import Edge, Graph
-
-
-class EdgeStateAnalysis(TypedDict):
-    """Analysis result for edge states."""
-
-    has_unknown: bool
-    has_taken: bool
-    all_skipped: bool
-
-
-@final
-class EdgeStateManager:
-    """
-    Manages edge states and transitions during graph execution.
-
-    This handles edge state changes and provides analysis of edge
-    states for decision making during execution.
-    """
-
-    def __init__(self, graph: Graph) -> None:
-        """
-        Initialize the edge state manager.
-
-        Args:
-            graph: The workflow graph
-        """
-        self.graph = graph
-        self._lock = threading.RLock()
-
-    def mark_edge_taken(self, edge_id: str) -> None:
-        """
-        Mark an edge as TAKEN.
-
-        Args:
-            edge_id: The ID of the edge to mark
-        """
-        with self._lock:
-            self.graph.edges[edge_id].state = NodeState.TAKEN
-
-    def mark_edge_skipped(self, edge_id: str) -> None:
-        """
-        Mark an edge as SKIPPED.
-
-        Args:
-            edge_id: The ID of the edge to mark
-        """
-        with self._lock:
-            self.graph.edges[edge_id].state = NodeState.SKIPPED
-
-    def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
-        """
-        Analyze the states of edges and return summary flags.
-
-        Args:
-            edges: List of edges to analyze
-
-        Returns:
-            Analysis result with state flags
-        """
-        with self._lock:
-            states = {edge.state for edge in edges}
-
-            return EdgeStateAnalysis(
-                has_unknown=NodeState.UNKNOWN in states,
-                has_taken=NodeState.TAKEN in states,
-                all_skipped=states == {NodeState.SKIPPED} if states else True,
-            )
-
-    def get_edge_state(self, edge_id: str) -> NodeState:
-        """
-        Get the current state of an edge.
-
-        Args:
-            edge_id: The ID of the edge
-
-        Returns:
-            The current edge state
-        """
-        with self._lock:
-            return self.graph.edges[edge_id].state
-
-    def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
-        """
-        Categorize branch edges into selected and unselected.
-
-        Args:
-            node_id: The ID of the branch node
-            selected_handle: The handle of the selected edge
-
-        Returns:
-            A tuple of (selected_edges, unselected_edges)
-        """
-        with self._lock:
-            outgoing_edges = self.graph.get_outgoing_edges(node_id)
-            selected_edges: list[Edge] = []
-            unselected_edges: list[Edge] = []
-
-            for edge in outgoing_edges:
-                if edge.source_handle == selected_handle:
-                    selected_edges.append(edge)
-                else:
-                    unselected_edges.append(edge)
-
-            return selected_edges, unselected_edges
diff --git a/api/core/workflow/graph_engine/state_management/execution_tracker.py b/api/core/workflow/graph_engine/state_management/execution_tracker.py
deleted file mode 100644
index 01fa80f2ce..0000000000
--- a/api/core/workflow/graph_engine/state_management/execution_tracker.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""
-Tracker for currently executing nodes.
-"""
-
-import threading
-from typing import final
-
-
-@final
-class ExecutionTracker:
-    """
-    Tracks nodes that are currently being executed.
-
-    This replaces the ExecutingNodesManager with a cleaner interface
-    focused on tracking which nodes are in progress.
-    """
-
-    def __init__(self) -> None:
-        """Initialize the execution tracker."""
-        self._executing_nodes: set[str] = set()
-        self._lock = threading.RLock()
-
-    def add(self, node_id: str) -> None:
-        """
-        Mark a node as executing.
-
-        Args:
-            node_id: The ID of the node starting execution
-        """
-        with self._lock:
-            self._executing_nodes.add(node_id)
-
-    def remove(self, node_id: str) -> None:
-        """
-        Mark a node as no longer executing.
-
-        Args:
-            node_id: The ID of the node finishing execution
-        """
-        with self._lock:
-            self._executing_nodes.discard(node_id)
-
-    def is_executing(self, node_id: str) -> bool:
-        """
-        Check if a node is currently executing.
-
-        Args:
-            node_id: The ID of the node to check
-
-        Returns:
-            True if the node is executing
-        """
-        with self._lock:
-            return node_id in self._executing_nodes
-
-    def is_empty(self) -> bool:
-        """
-        Check if no nodes are currently executing.
-
-        Returns:
-            True if no nodes are executing
-        """
-        with self._lock:
-            return len(self._executing_nodes) == 0
-
-    def count(self) -> int:
-        """
-        Get the count of currently executing nodes.
-
-        Returns:
-            Number of executing nodes
-        """
-        with self._lock:
-            return len(self._executing_nodes)
-
-    def get_executing_nodes(self) -> set[str]:
-        """
-        Get a copy of the set of executing node IDs.
-
-        Returns:
-            Set of node IDs currently executing
-        """
-        with self._lock:
-            return self._executing_nodes.copy()
-
-    def clear(self) -> None:
-        """Clear all executing nodes."""
-        with self._lock:
-            self._executing_nodes.clear()
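ExecutionTracker's interface survives inside UnifiedStateManager under more explicit names, with thin aliases retained for existing callers. A migration sketch (state is any UnifiedStateManager instance; the node ID is hypothetical):

    state.add("n1")              # preferred new name: state.start_execution("n1")
    assert state.is_executing("n1")
    assert state.count() == 1    # preferred new name: state.get_executing_count()
    state.remove("n1")           # preferred new name: state.finish_execution("n1")
    assert state.is_empty()
    state.clear()                # preferred new name: state.clear_executing()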
diff --git a/api/core/workflow/graph_engine/state_management/node_state_manager.py b/api/core/workflow/graph_engine/state_management/node_state_manager.py
deleted file mode 100644
index d5ed42ad1d..0000000000
--- a/api/core/workflow/graph_engine/state_management/node_state_manager.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""
-Manager for node states during graph execution.
-"""
-
-import queue
-import threading
-from typing import final
-
-from core.workflow.enums import NodeState
-from core.workflow.graph import Graph
-
-
-@final
-class NodeStateManager:
-    """
-    Manages node states and the ready queue for execution.
-
-    This centralizes node state transitions and enqueueing logic,
-    ensuring thread-safe operations on node states.
-    """
-
-    def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
-        """
-        Initialize the node state manager.
-
-        Args:
-            graph: The workflow graph
-            ready_queue: Queue for nodes ready to execute
-        """
-        self.graph = graph
-        self.ready_queue = ready_queue
-        self._lock = threading.RLock()
-
-    def enqueue_node(self, node_id: str) -> None:
-        """
-        Mark a node as TAKEN and add it to the ready queue.
-
-        This combines the state transition and enqueueing operations
-        that always occur together when preparing a node for execution.
-
-        Args:
-            node_id: The ID of the node to enqueue
-        """
-        with self._lock:
-            self.graph.nodes[node_id].state = NodeState.TAKEN
-            self.ready_queue.put(node_id)
-
-    def mark_node_skipped(self, node_id: str) -> None:
-        """
-        Mark a node as SKIPPED.
-
-        Args:
-            node_id: The ID of the node to skip
-        """
-        with self._lock:
-            self.graph.nodes[node_id].state = NodeState.SKIPPED
-
-    def is_node_ready(self, node_id: str) -> bool:
-        """
-        Check if a node is ready to be executed.
-
-        A node is ready when all its incoming edges from taken branches
-        have been satisfied.
-
-        Args:
-            node_id: The ID of the node to check
-
-        Returns:
-            True if the node is ready for execution
-        """
-        with self._lock:
-            # Get all incoming edges to this node
-            incoming_edges = self.graph.get_incoming_edges(node_id)
-
-            # If no incoming edges, node is always ready
-            if not incoming_edges:
-                return True
-
-            # If any edge is UNKNOWN, node is not ready
-            if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
-                return False
-
-            # Node is ready if at least one edge is TAKEN
-            return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
-
-    def get_node_state(self, node_id: str) -> NodeState:
-        """
-        Get the current state of a node.
-
-        Args:
-            node_id: The ID of the node
-
-        Returns:
-            The current node state
-        """
-        with self._lock:
-            return self.graph.nodes[node_id].state
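The readiness rule (no incoming edge still UNKNOWN, and at least one TAKEN) carries over unchanged into UnifiedStateManager.is_node_ready below. A standalone restatement with its edge cases (again using a stand-in enum):

    from enum import Enum

    class NodeState(Enum):  # stand-in for core.workflow.enums.NodeState
        UNKNOWN = "unknown"
        TAKEN = "taken"
        SKIPPED = "skipped"

    def is_ready(incoming: list[NodeState]) -> bool:
        # Mirrors NodeStateManager.is_node_ready
        if not incoming:
            return True   # nodes with no dependencies are always ready
        if NodeState.UNKNOWN in incoming:
            return False  # an upstream branch decision is still pending
        return NodeState.TAKEN in incoming

    assert is_ready([])                                        # root-style node
    assert is_ready([NodeState.TAKEN, NodeState.SKIPPED])      # one live path suffices
    assert not is_ready([NodeState.TAKEN, NodeState.UNKNOWN])  # wait for all branches
    assert not is_ready([NodeState.SKIPPED])                   # every path skipped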
diff --git a/api/core/workflow/graph_engine/state_management/unified_state_manager.py b/api/core/workflow/graph_engine/state_management/unified_state_manager.py
new file mode 100644
index 0000000000..3f50b68213
--- /dev/null
+++ b/api/core/workflow/graph_engine/state_management/unified_state_manager.py
@@ -0,0 +1,343 @@
+"""
+Unified state manager that combines node, edge, and execution tracking.
+
+This is a proposed simplification that merges NodeStateManager, EdgeStateManager,
+and ExecutionTracker into a single cohesive class.
+"""
+
+import queue
+import threading
+from collections.abc import Sequence
+from typing import TypedDict, final
+
+from core.workflow.enums import NodeState
+from core.workflow.graph import Edge, Graph
+
+
+class EdgeStateAnalysis(TypedDict):
+    """Analysis result for edge states."""
+
+    has_unknown: bool
+    has_taken: bool
+    all_skipped: bool
+
+
+@final
+class UnifiedStateManager:
+    """
+    Unified manager for all graph state operations.
+
+    This class combines the responsibilities of:
+    - NodeStateManager: Node state transitions and ready queue
+    - EdgeStateManager: Edge state transitions and analysis
+    - ExecutionTracker: Tracking executing nodes
+
+    Benefits:
+    - Single lock for all state operations (reduced contention)
+    - Cohesive state management interface
+    - Simplified dependency injection
+    """
+
+    def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
+        """
+        Initialize the unified state manager.
+
+        Args:
+            graph: The workflow graph
+            ready_queue: Queue for nodes ready to execute
+        """
+        self.graph = graph
+        self.ready_queue = ready_queue
+        self._lock = threading.RLock()
+
+        # Execution tracking state
+        self._executing_nodes: set[str] = set()
+
+    # ============= Node State Operations =============
+
+    def enqueue_node(self, node_id: str) -> None:
+        """
+        Mark a node as TAKEN and add it to the ready queue.
+
+        This combines the state transition and enqueueing operations
+        that always occur together when preparing a node for execution.
+
+        Args:
+            node_id: The ID of the node to enqueue
+        """
+        with self._lock:
+            self.graph.nodes[node_id].state = NodeState.TAKEN
+            self.ready_queue.put(node_id)
+
+    def mark_node_skipped(self, node_id: str) -> None:
+        """
+        Mark a node as SKIPPED.
+
+        Args:
+            node_id: The ID of the node to skip
+        """
+        with self._lock:
+            self.graph.nodes[node_id].state = NodeState.SKIPPED
+
+    def is_node_ready(self, node_id: str) -> bool:
+        """
+        Check if a node is ready to be executed.
+
+        A node is ready when all its incoming edges from taken branches
+        have been satisfied.
+
+        Args:
+            node_id: The ID of the node to check
+
+        Returns:
+            True if the node is ready for execution
+        """
+        with self._lock:
+            # Get all incoming edges to this node
+            incoming_edges = self.graph.get_incoming_edges(node_id)
+
+            # If no incoming edges, node is always ready
+            if not incoming_edges:
+                return True
+
+            # If any edge is UNKNOWN, node is not ready
+            if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
+                return False
+
+            # Node is ready if at least one edge is TAKEN
+            return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
+
+    def get_node_state(self, node_id: str) -> NodeState:
+        """
+        Get the current state of a node.
+
+        Args:
+            node_id: The ID of the node
+
+        Returns:
+            The current node state
+        """
+        with self._lock:
+            return self.graph.nodes[node_id].state
+
+    # ============= Edge State Operations =============
+
+    def mark_edge_taken(self, edge_id: str) -> None:
+        """
+        Mark an edge as TAKEN.
+
+        Args:
+            edge_id: The ID of the edge to mark
+        """
+        with self._lock:
+            self.graph.edges[edge_id].state = NodeState.TAKEN
+
+    def mark_edge_skipped(self, edge_id: str) -> None:
+        """
+        Mark an edge as SKIPPED.
+
+        Args:
+            edge_id: The ID of the edge to mark
+        """
+        with self._lock:
+            self.graph.edges[edge_id].state = NodeState.SKIPPED
+
+    def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
+        """
+        Analyze the states of edges and return summary flags.
+
+        Args:
+            edges: List of edges to analyze
+
+        Returns:
+            Analysis result with state flags
+        """
+        with self._lock:
+            states = {edge.state for edge in edges}
+
+            return EdgeStateAnalysis(
+                has_unknown=NodeState.UNKNOWN in states,
+                has_taken=NodeState.TAKEN in states,
+                all_skipped=states == {NodeState.SKIPPED} if states else True,
+            )
+
+    def get_edge_state(self, edge_id: str) -> NodeState:
+        """
+        Get the current state of an edge.
+
+        Args:
+            edge_id: The ID of the edge
+
+        Returns:
+            The current edge state
+        """
+        with self._lock:
+            return self.graph.edges[edge_id].state
+
+    def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
+        """
+        Categorize branch edges into selected and unselected.
+
+        Args:
+            node_id: The ID of the branch node
+            selected_handle: The handle of the selected edge
+
+        Returns:
+            A tuple of (selected_edges, unselected_edges)
+        """
+        with self._lock:
+            outgoing_edges = self.graph.get_outgoing_edges(node_id)
+            selected_edges: list[Edge] = []
+            unselected_edges: list[Edge] = []
+
+            for edge in outgoing_edges:
+                if edge.source_handle == selected_handle:
+                    selected_edges.append(edge)
+                else:
+                    unselected_edges.append(edge)
+
+            return selected_edges, unselected_edges
+
+    # ============= Execution Tracking Operations =============
+
+    def start_execution(self, node_id: str) -> None:
+        """
+        Mark a node as executing.
+
+        Args:
+            node_id: The ID of the node starting execution
+        """
+        with self._lock:
+            self._executing_nodes.add(node_id)
+
+    def finish_execution(self, node_id: str) -> None:
+        """
+        Mark a node as no longer executing.
+
+        Args:
+            node_id: The ID of the node finishing execution
+        """
+        with self._lock:
+            self._executing_nodes.discard(node_id)
+
+    def is_executing(self, node_id: str) -> bool:
+        """
+        Check if a node is currently executing.
+
+        Args:
+            node_id: The ID of the node to check
+
+        Returns:
+            True if the node is executing
+        """
+        with self._lock:
+            return node_id in self._executing_nodes
+
+    def get_executing_count(self) -> int:
+        """
+        Get the count of currently executing nodes.
+
+        Returns:
+            Number of executing nodes
+        """
+        with self._lock:
+            return len(self._executing_nodes)
+
+    def get_executing_nodes(self) -> set[str]:
+        """
+        Get a copy of the set of executing node IDs.
+
+        Returns:
+            Set of node IDs currently executing
+        """
+        with self._lock:
+            return self._executing_nodes.copy()
+
+    def clear_executing(self) -> None:
+        """Clear all executing nodes."""
+        with self._lock:
+            self._executing_nodes.clear()
+
+    # ============= Composite Operations =============
+
+    def is_execution_complete(self) -> bool:
+        """
+        Check if graph execution is complete.
+
+        Execution is complete when:
+        - Ready queue is empty
+        - No nodes are executing
+
+        Returns:
+            True if execution is complete
+        """
+        with self._lock:
+            return self.ready_queue.empty() and len(self._executing_nodes) == 0
+
+    def get_queue_depth(self) -> int:
+        """
+        Get the current depth of the ready queue.
+
+        Returns:
+            Number of nodes in the ready queue
+        """
+        return self.ready_queue.qsize()
+
+    def get_execution_stats(self) -> dict[str, int]:
+        """
+        Get execution statistics.
+
+        Returns:
+            Dictionary with execution statistics
+        """
+        with self._lock:
+            taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN)
+            skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED)
+            unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN)
+
+            return {
+                "queue_depth": self.ready_queue.qsize(),
+                "executing": len(self._executing_nodes),
+                "taken_nodes": taken_nodes,
+                "skipped_nodes": skipped_nodes,
+                "unknown_nodes": unknown_nodes,
+            }
+
+    # ============= Backward Compatibility Methods =============
+    # These methods provide compatibility with existing code
+
+    @property
+    def execution_tracker(self) -> "UnifiedStateManager":
+        """Compatibility property for ExecutionTracker access."""
+        return self
+
+    @property
+    def node_state_manager(self) -> "UnifiedStateManager":
+        """Compatibility property for NodeStateManager access."""
+        return self
+
+    @property
+    def edge_state_manager(self) -> "UnifiedStateManager":
+        """Compatibility property for EdgeStateManager access."""
+        return self
+
+    # ExecutionTracker compatibility methods
+    def add(self, node_id: str) -> None:
+        """Compatibility method for ExecutionTracker.add()."""
+        self.start_execution(node_id)
+
+    def remove(self, node_id: str) -> None:
+        """Compatibility method for ExecutionTracker.remove()."""
+        self.finish_execution(node_id)
+
+    def is_empty(self) -> bool:
+        """Compatibility method for ExecutionTracker.is_empty()."""
+        # Route through the locked counter rather than reading the set unlocked
+        return self.get_executing_count() == 0
+
+    def count(self) -> int:
+        """Compatibility method for ExecutionTracker.count()."""
+        return self.get_executing_count()
+
+    def clear(self) -> None:
+        """Compatibility method for ExecutionTracker.clear()."""
+        self.clear_executing()
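is_execution_complete and get_execution_stats give the dispatcher a one-call view of progress. A usage sketch (state wired as in the GraphEngine hunk above; the node ID and the stats values shown are hypothetical):

    state.enqueue_node("start")        # "start" -> TAKEN, pushed onto ready_queue
    node_id = state.ready_queue.get()  # a worker claims it
    state.start_execution(node_id)

    stats = state.get_execution_stats()
    # e.g. {"queue_depth": 0, "executing": 1, "taken_nodes": 1,
    #       "skipped_nodes": 0, "unknown_nodes": 7}

    state.finish_execution(node_id)
    assert state.is_execution_complete()  # queue drained and nothing executing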
diff --git a/api/core/workflow/graph_engine/worker_management/enhanced_worker_pool.py b/api/core/workflow/graph_engine/worker_management/enhanced_worker_pool.py
new file mode 100644
index 0000000000..015fb79e4f
--- /dev/null
+++ b/api/core/workflow/graph_engine/worker_management/enhanced_worker_pool.py
@@ -0,0 +1,360 @@
+"""
+Enhanced worker pool with integrated activity tracking and dynamic scaling.
+
+This is a proposed simplification that merges WorkerPool, ActivityTracker,
+and DynamicScaler into a single cohesive class.
+"""
+
+import queue
+import threading
+import time
+from typing import TYPE_CHECKING, final
+
+from configs import dify_config
+from core.workflow.graph import Graph
+from core.workflow.graph_events import GraphNodeEventBase
+
+from ..worker import Worker
+
+if TYPE_CHECKING:
+    from contextvars import Context
+
+    from flask import Flask
+
+
+@final
+class EnhancedWorkerPool:
+    """
+    Enhanced worker pool with integrated features.
+
+    This class combines the responsibilities of:
+    - WorkerPool: Managing worker threads
+    - ActivityTracker: Tracking worker activity
+    - DynamicScaler: Making scaling decisions
+
+    Benefits:
+    - Simplified interface with fewer classes
+    - Direct integration of related features
+    - Reduced inter-class communication overhead
+    """
+
+    def __init__(
+        self,
+        ready_queue: queue.Queue[str],
+        event_queue: queue.Queue[GraphNodeEventBase],
+        graph: Graph,
+        flask_app: "Flask | None" = None,
+        context_vars: "Context | None" = None,
+        min_workers: int | None = None,
+        max_workers: int | None = None,
+        scale_up_threshold: int | None = None,
+        scale_down_idle_time: float | None = None,
+    ) -> None:
+        """
+        Initialize the enhanced worker pool.
+
+        Args:
+            ready_queue: Queue of nodes ready for execution
+            event_queue: Queue for worker events
+            graph: The workflow graph
+            flask_app: Optional Flask app for context preservation
+            context_vars: Optional context variables
+            min_workers: Minimum number of workers
+            max_workers: Maximum number of workers
+            scale_up_threshold: Queue depth to trigger scale up
+            scale_down_idle_time: Seconds before scaling down idle workers
+        """
+        self.ready_queue = ready_queue
+        self.event_queue = event_queue
+        self.graph = graph
+        self.flask_app = flask_app
+        self.context_vars = context_vars
+
+        # Scaling parameters
+        self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
+        self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
+        self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
+        self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
+
+        # Worker management
+        self.workers: list[Worker] = []
+        self._worker_counter = 0
+        self._lock = threading.RLock()
+        self._running = False
+
+        # Activity tracking (integrated)
+        self._worker_activity: dict[int, tuple[bool, float]] = {}
+
+        # Scaling control
+        self._last_scale_check = time.time()
+        self._scale_check_interval = 1.0  # Check scaling every second
+
+    def start(self, initial_count: int | None = None) -> None:
+        """
+        Start the worker pool with initial workers.
+
+        Args:
+            initial_count: Number of workers to start with (auto-calculated if None)
+        """
+        with self._lock:
+            if self._running:
+                return
+
+            self._running = True
+
+            # Calculate initial worker count if not specified
+            if initial_count is None:
+                initial_count = self._calculate_initial_workers()
+
+            # Create initial workers
+            for _ in range(initial_count):
+                self._add_worker()
+
+    def stop(self) -> None:
+        """Stop all workers in the pool."""
+        with self._lock:
+            self._running = False
+
+            # Stop all workers
+            for worker in self.workers:
+                worker.stop()
+
+            # Wait for workers to finish
+            for worker in self.workers:
+                if worker.is_alive():
+                    worker.join(timeout=10.0)
+
+            self.workers.clear()
+            self._worker_activity.clear()
+
+    def check_and_scale(self) -> None:
+        """
+        Check and perform scaling if needed.
+
+        This method should be called periodically to adjust pool size.
+        """
+        current_time = time.time()
+
+        # Rate limit scaling checks
+        if current_time - self._last_scale_check < self._scale_check_interval:
+            return
+
+        self._last_scale_check = current_time
+
+        with self._lock:
+            if not self._running:
+                return
+
+            current_count = len(self.workers)
+            queue_depth = self.ready_queue.qsize()
+
+            # Check for scale up
+            if self._should_scale_up(current_count, queue_depth):
+                self._add_worker()
+
+            # Check for scale down
+            idle_workers = self._get_idle_workers(current_time)
+            if idle_workers and self._should_scale_down(current_count):
+                # Remove the most idle worker
+                self._remove_worker(idle_workers[0])
+
+    # ============= Private Methods =============
+
+    def _calculate_initial_workers(self) -> int:
+        """
+        Calculate initial number of workers based on graph complexity.
+
+        Returns:
+            Initial worker count
+        """
+        # Simple heuristic: start with min_workers, scale based on graph size
+        node_count = len(self.graph.nodes)
+
+        if node_count < 10:
+            return self.min_workers
+        elif node_count < 50:
+            return min(self.min_workers + 1, self.max_workers)
+        else:
+            return min(self.min_workers + 2, self.max_workers)
+
+    def _should_scale_up(self, current_count: int, queue_depth: int) -> bool:
+        """
+        Determine if pool should scale up.
+
+        Args:
+            current_count: Current number of workers
+            queue_depth: Current queue depth
+
+        Returns:
+            True if should scale up
+        """
+        if current_count >= self.max_workers:
+            return False
+
+        # Scale up if queue is deep
+        if queue_depth > self.scale_up_threshold:
+            return True
+
+        # Scale up if all workers are busy and queue is not empty
+        active_count = self._get_active_count()
+        if active_count == current_count and queue_depth > 0:
+            return True
+
+        return False
+
+    def _should_scale_down(self, current_count: int) -> bool:
+        """
+        Determine if pool should scale down.
+
+        Args:
+            current_count: Current number of workers
+
+        Returns:
+            True if should scale down
+        """
+        return current_count > self.min_workers
+
+    def _add_worker(self) -> None:
+        """Add a new worker to the pool."""
+        worker_id = self._worker_counter
+        self._worker_counter += 1
+
+        # Create worker with activity callbacks
+        worker = Worker(
+            ready_queue=self.ready_queue,
+            event_queue=self.event_queue,
+            graph=self.graph,
+            worker_id=worker_id,
+            flask_app=self.flask_app,
+            context_vars=self.context_vars,
+            on_idle_callback=self._on_worker_idle,
+            on_active_callback=self._on_worker_active,
+        )
+
+        worker.start()
+        self.workers.append(worker)
+        self._worker_activity[worker_id] = (False, time.time())
+
+    def _remove_worker(self, worker_id: int) -> None:
+        """
+        Remove a specific worker from the pool.
+
+        Args:
+            worker_id: ID of worker to remove
+        """
+        worker_to_remove = None
+        for worker in self.workers:
+            if worker.worker_id == worker_id:
+                worker_to_remove = worker
+                break
+
+        if worker_to_remove:
+            worker_to_remove.stop()
+            self.workers.remove(worker_to_remove)
+            self._worker_activity.pop(worker_id, None)
+
+            if worker_to_remove.is_alive():
+                worker_to_remove.join(timeout=1.0)
+
+    def _on_worker_idle(self, worker_id: int) -> None:
+        """
+        Callback when worker becomes idle.
+
+        Args:
+            worker_id: ID of the idle worker
+        """
+        with self._lock:
+            self._worker_activity[worker_id] = (False, time.time())
+
+    def _on_worker_active(self, worker_id: int) -> None:
+        """
+        Callback when worker becomes active.
+
+        Args:
+            worker_id: ID of the active worker
+        """
+        with self._lock:
+            self._worker_activity[worker_id] = (True, time.time())
+
+    def _get_idle_workers(self, current_time: float) -> list[int]:
+        """
+        Get list of workers that have been idle too long.
+
+        Args:
+            current_time: Current timestamp
+
+        Returns:
+            List of idle worker IDs sorted by idle time (longest first)
+        """
+        idle_workers: list[tuple[int, float]] = []
+
+        for worker_id, (is_active, last_change) in self._worker_activity.items():
+            if not is_active:
+                idle_time = current_time - last_change
+                if idle_time > self.scale_down_idle_time:
+                    idle_workers.append((worker_id, idle_time))
+
+        # Sort by idle time (longest first)
+        idle_workers.sort(key=lambda x: x[1], reverse=True)
+        return [worker_id for worker_id, _ in idle_workers]
+
+    def _get_active_count(self) -> int:
+        """
+        Get count of currently active workers.
+
+        Returns:
+            Number of active workers
+        """
+        return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
+
+    # ============= Public Status Methods =============
+
+    def get_worker_count(self) -> int:
+        """Get current number of workers."""
+        with self._lock:
+            return len(self.workers)
+
+    def get_status(self) -> dict[str, int]:
+        """
+        Get pool status information.
+
+        Returns:
+            Dictionary with status information
+        """
+        with self._lock:
+            return {
+                "total_workers": len(self.workers),
+                "active_workers": self._get_active_count(),
+                "idle_workers": len(self.workers) - self._get_active_count(),
+                "queue_depth": self.ready_queue.qsize(),
+                "min_workers": self.min_workers,
+                "max_workers": self.max_workers,
+            }
+
+    # ============= Backward Compatibility =============
+
+    def scale_up(self) -> None:
+        """Compatibility method for manual scale up."""
+        with self._lock:
+            if self._running and len(self.workers) < self.max_workers:
+                self._add_worker()
+
+    def scale_down(self, worker_ids: list[int]) -> None:
+        """Compatibility method for manual scale down."""
+        with self._lock:
+            if not self._running:
+                return
+
+            for worker_id in worker_ids:
+                if len(self.workers) > self.min_workers:
+                    self._remove_worker(worker_id)
+
+    def check_scaling(self, queue_depth: int, executing_count: int) -> None:
+        """
+        Compatibility method for checking scaling.
+
+        Args:
+            queue_depth: Current queue depth (ignored, we check directly)
+            executing_count: Number of executing nodes (ignored)
+        """
+        self.check_and_scale()
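The heart of check_and_scale is _should_scale_up: grow when the queue backs up past scale_up_threshold, or when every worker is busy while work is still queued, but never beyond max_workers. A standalone restatement of that decision:

    def should_scale_up(
        current: int, active: int, queue_depth: int,
        max_workers: int, scale_up_threshold: int,
    ) -> bool:
        # Mirrors EnhancedWorkerPool._should_scale_up
        if current >= max_workers:
            return False
        if queue_depth > scale_up_threshold:
            return True
        return active == current and queue_depth > 0

    assert should_scale_up(2, 1, 10, max_workers=4, scale_up_threshold=5)      # deep queue
    assert should_scale_up(2, 2, 1, max_workers=4, scale_up_threshold=5)       # all busy, work waiting
    assert not should_scale_up(4, 4, 50, max_workers=4, scale_up_threshold=5)  # already at cap
    assert not should_scale_up(2, 1, 0, max_workers=4, scale_up_threshold=5)   # idle capacity, empty queue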