refactor(graph_engine): Remove backward compatibility code

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-09-01 02:41:16 +08:00
parent e2f4c9ba8d
commit 202fdfcb81
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
8 changed files with 38 additions and 121 deletions

View File

@ -56,8 +56,7 @@ class EventHandlerRegistry:
event_collector: "EventCollector",
branch_handler: "BranchHandler",
edge_processor: "EdgeProcessor",
node_state_manager: "UnifiedStateManager",
execution_tracker: "UnifiedStateManager",
state_manager: "UnifiedStateManager",
error_handler: "ErrorHandler",
) -> None:
"""
@ -71,8 +70,7 @@ class EventHandlerRegistry:
event_collector: Event collector for collecting events
branch_handler: Branch handler for branch node processing
edge_processor: Edge processor for edge traversal
node_state_manager: Node state manager
execution_tracker: Execution tracker
state_manager: Unified state manager
error_handler: Error handler
"""
self._graph = graph
@ -82,8 +80,7 @@ class EventHandlerRegistry:
self._event_collector = event_collector
self._branch_handler = branch_handler
self._edge_processor = edge_processor
self._node_state_manager = node_state_manager
self._execution_tracker = execution_tracker
self._state_manager = state_manager
self._error_handler = error_handler
def handle_event(self, event: GraphNodeEventBase) -> None:
@ -199,11 +196,11 @@ class EventHandlerRegistry:
# Enqueue ready nodes
for node_id in ready_nodes:
self._node_state_manager.enqueue_node(node_id)
self._execution_tracker.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._execution_tracker.remove(event.node_id)
self._state_manager.finish_execution(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
@ -232,7 +229,7 @@ class EventHandlerRegistry:
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._execution_tracker.remove(event.node_id)
self._state_manager.finish_execution(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""

View File

@ -137,20 +137,18 @@ class GraphEngine:
self.node_readiness_checker = NodeReadinessChecker(self.graph)
self.edge_processor = EdgeProcessor(
graph=self.graph,
edge_state_manager=self.state_manager,
node_state_manager=self.state_manager,
state_manager=self.state_manager,
response_coordinator=self.response_coordinator,
)
self.skip_propagator = SkipPropagator(
graph=self.graph,
edge_state_manager=self.state_manager,
node_state_manager=self.state_manager,
state_manager=self.state_manager,
)
self.branch_handler = BranchHandler(
graph=self.graph,
edge_processor=self.edge_processor,
skip_propagator=self.skip_propagator,
edge_state_manager=self.state_manager,
state_manager=self.state_manager,
)
# Event handler registry with all dependencies
@ -162,8 +160,7 @@ class GraphEngine:
event_collector=self.event_collector,
branch_handler=self.branch_handler,
edge_processor=self.edge_processor,
node_state_manager=self.state_manager,
execution_tracker=self.state_manager,
state_manager=self.state_manager,
error_handler=self.error_handler,
)
@ -180,8 +177,7 @@ class GraphEngine:
# Orchestration
self.execution_coordinator = ExecutionCoordinator(
graph_execution=self.graph_execution,
node_state_manager=self.state_manager,
execution_tracker=self.state_manager,
state_manager=self.state_manager,
event_handler=self.event_handler_registry,
event_collector=self.event_collector,
command_processor=self.command_processor,
@ -334,7 +330,7 @@ class GraphEngine:
# Enqueue root node
root_node = self.graph.root_node
self.state_manager.enqueue_node(root_node.id)
self.state_manager.add(root_node.id)
self.state_manager.start_execution(root_node.id)
# Start dispatcher
self.dispatcher.start()

View File

@ -27,7 +27,7 @@ class BranchHandler:
graph: Graph,
edge_processor: EdgeProcessor,
skip_propagator: SkipPropagator,
edge_state_manager: UnifiedStateManager,
state_manager: UnifiedStateManager,
) -> None:
"""
Initialize the branch handler.
@ -36,12 +36,12 @@ class BranchHandler:
graph: The workflow graph
edge_processor: Processor for edges
skip_propagator: Propagator for skip states
edge_state_manager: Manager for edge states
state_manager: Unified state manager
"""
self.graph = graph
self.edge_processor = edge_processor
self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager
self.state_manager = state_manager
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
@ -63,7 +63,7 @@ class BranchHandler:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
_, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self.skip_propagator.skip_branch_paths(unselected_edges)

View File

@ -25,8 +25,7 @@ class EdgeProcessor:
def __init__(
self,
graph: Graph,
edge_state_manager: UnifiedStateManager,
node_state_manager: UnifiedStateManager,
state_manager: UnifiedStateManager,
response_coordinator: ResponseStreamCoordinator,
) -> None:
"""
@ -34,13 +33,11 @@ class EdgeProcessor:
Args:
graph: The workflow graph
edge_state_manager: Manager for edge states
node_state_manager: Manager for node states
state_manager: Unified state manager
response_coordinator: Response stream coordinator
"""
self.graph = graph
self.edge_state_manager = edge_state_manager
self.node_state_manager = node_state_manager
self.state_manager = state_manager
self.response_coordinator = response_coordinator
def process_node_success(
@ -107,7 +104,7 @@ class EdgeProcessor:
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
@ -132,14 +129,14 @@ class EdgeProcessor:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self.edge_state_manager.mark_edge_taken(edge.id)
self.state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes: list[str] = []
if self.node_state_manager.is_node_ready(edge.head):
if self.state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, streaming_events
@ -151,4 +148,4 @@ class EdgeProcessor:
Args:
edge: The edge to skip
"""
self.edge_state_manager.mark_edge_skipped(edge.id)
self.state_manager.mark_edge_skipped(edge.id)

View File

@ -22,20 +22,17 @@ class SkipPropagator:
def __init__(
self,
graph: Graph,
edge_state_manager: UnifiedStateManager,
node_state_manager: UnifiedStateManager,
state_manager: UnifiedStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
edge_state_manager: Manager for edge states
node_state_manager: Manager for node states
state_manager: Unified state manager
"""
self.graph = graph
self.edge_state_manager = edge_state_manager
self.node_state_manager = node_state_manager
self.state_manager = state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
@ -53,7 +50,7 @@ class SkipPropagator:
incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges)
edge_states = self.state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
@ -62,7 +59,7 @@ class SkipPropagator:
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Enqueue node
self.node_state_manager.enqueue_node(downstream_node_id)
self.state_manager.enqueue_node(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
@ -77,12 +74,12 @@ class SkipPropagator:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self.node_state_manager.mark_node_skipped(node_id)
self.state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self.graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self.edge_state_manager.mark_edge_skipped(edge.id)
self.state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
@ -94,5 +91,5 @@ class SkipPropagator:
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self.edge_state_manager.mark_edge_skipped(edge.id)
self.state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)

View File

@ -26,8 +26,7 @@ class ExecutionCoordinator:
def __init__(
self,
graph_execution: GraphExecution,
node_state_manager: UnifiedStateManager,
execution_tracker: UnifiedStateManager,
state_manager: UnifiedStateManager,
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
command_processor: CommandProcessor,
@ -38,16 +37,14 @@ class ExecutionCoordinator:
Args:
graph_execution: Graph execution aggregate
node_state_manager: Manager for node states
execution_tracker: Tracker for executing nodes
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event collector for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self.graph_execution = graph_execution
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.state_manager = state_manager
self.event_handler = event_handler
self.event_collector = event_collector
self.command_processor = command_processor
@ -59,8 +56,8 @@ class ExecutionCoordinator:
def check_scaling(self) -> None:
"""Check and perform worker scaling if needed."""
queue_depth = self.node_state_manager.ready_queue.qsize()
executing_count = self.execution_tracker.count()
queue_depth = self.state_manager.ready_queue.qsize()
executing_count = self.state_manager.get_executing_count()
self.worker_pool.check_scaling(queue_depth, executing_count)
def is_execution_complete(self) -> bool:
@ -75,7 +72,7 @@ class ExecutionCoordinator:
return True
# Complete if no work remains
return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty()
return self.state_manager.is_execution_complete()
def mark_complete(self) -> None:
"""Mark execution as complete."""

View File

@ -302,42 +302,3 @@ class UnifiedStateManager:
"skipped_nodes": skipped_nodes,
"unknown_nodes": unknown_nodes,
}
# ============= Backward Compatibility Methods =============
# These methods provide compatibility with existing code
@property
def execution_tracker(self) -> "UnifiedStateManager":
"""Compatibility property for ExecutionTracker access."""
return self
@property
def node_state_manager(self) -> "UnifiedStateManager":
"""Compatibility property for NodeStateManager access."""
return self
@property
def edge_state_manager(self) -> "UnifiedStateManager":
"""Compatibility property for EdgeStateManager access."""
return self
# ExecutionTracker compatibility methods
def add(self, node_id: str) -> None:
"""Compatibility method for ExecutionTracker.add()."""
self.start_execution(node_id)
def remove(self, node_id: str) -> None:
"""Compatibility method for ExecutionTracker.remove()."""
self.finish_execution(node_id)
def is_empty(self) -> bool:
"""Compatibility method for ExecutionTracker.is_empty()."""
return len(self._executing_nodes) == 0
def count(self) -> int:
"""Compatibility method for ExecutionTracker.count()."""
return self.get_executing_count()
def clear(self) -> None:
"""Compatibility method for ExecutionTracker.clear()."""
self.clear_executing()

View File

@ -330,31 +330,3 @@ class EnhancedWorkerPool:
"min_workers": self.min_workers,
"max_workers": self.max_workers,
}
# ============= Backward Compatibility =============
def scale_up(self) -> None:
"""Compatibility method for manual scale up."""
with self._lock:
if self._running and len(self.workers) < self.max_workers:
self._add_worker()
def scale_down(self, worker_ids: list[int]) -> None:
"""Compatibility method for manual scale down."""
with self._lock:
if not self._running:
return
for worker_id in worker_ids:
if len(self.workers) > self.min_workers:
self._remove_worker(worker_id)
def check_scaling(self, queue_depth: int, executing_count: int) -> None:
"""
Compatibility method for checking scaling.
Args:
queue_depth: Current queue depth (ignored, we check directly)
executing_count: Number of executing nodes (ignored)
"""
self.check_and_scale()