mirror of https://github.com/langgenius/dify.git
refactor(graph_engine): Remove backward compatibility code
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
e2f4c9ba8d
commit
202fdfcb81
|
|
@ -56,8 +56,7 @@ class EventHandlerRegistry:
|
|||
event_collector: "EventCollector",
|
||||
branch_handler: "BranchHandler",
|
||||
edge_processor: "EdgeProcessor",
|
||||
node_state_manager: "UnifiedStateManager",
|
||||
execution_tracker: "UnifiedStateManager",
|
||||
state_manager: "UnifiedStateManager",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -71,8 +70,7 @@ class EventHandlerRegistry:
|
|||
event_collector: Event collector for collecting events
|
||||
branch_handler: Branch handler for branch node processing
|
||||
edge_processor: Edge processor for edge traversal
|
||||
node_state_manager: Node state manager
|
||||
execution_tracker: Execution tracker
|
||||
state_manager: Unified state manager
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self._graph = graph
|
||||
|
|
@ -82,8 +80,7 @@ class EventHandlerRegistry:
|
|||
self._event_collector = event_collector
|
||||
self._branch_handler = branch_handler
|
||||
self._edge_processor = edge_processor
|
||||
self._node_state_manager = node_state_manager
|
||||
self._execution_tracker = execution_tracker
|
||||
self._state_manager = state_manager
|
||||
self._error_handler = error_handler
|
||||
|
||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
||||
|
|
@ -199,11 +196,11 @@ class EventHandlerRegistry:
|
|||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self._node_state_manager.enqueue_node(node_id)
|
||||
self._execution_tracker.add(node_id)
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
|
|
@ -232,7 +229,7 @@ class EventHandlerRegistry:
|
|||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -137,20 +137,18 @@ class GraphEngine:
|
|||
self.node_readiness_checker = NodeReadinessChecker(self.graph)
|
||||
self.edge_processor = EdgeProcessor(
|
||||
graph=self.graph,
|
||||
edge_state_manager=self.state_manager,
|
||||
node_state_manager=self.state_manager,
|
||||
state_manager=self.state_manager,
|
||||
response_coordinator=self.response_coordinator,
|
||||
)
|
||||
self.skip_propagator = SkipPropagator(
|
||||
graph=self.graph,
|
||||
edge_state_manager=self.state_manager,
|
||||
node_state_manager=self.state_manager,
|
||||
state_manager=self.state_manager,
|
||||
)
|
||||
self.branch_handler = BranchHandler(
|
||||
graph=self.graph,
|
||||
edge_processor=self.edge_processor,
|
||||
skip_propagator=self.skip_propagator,
|
||||
edge_state_manager=self.state_manager,
|
||||
state_manager=self.state_manager,
|
||||
)
|
||||
|
||||
# Event handler registry with all dependencies
|
||||
|
|
@ -162,8 +160,7 @@ class GraphEngine:
|
|||
event_collector=self.event_collector,
|
||||
branch_handler=self.branch_handler,
|
||||
edge_processor=self.edge_processor,
|
||||
node_state_manager=self.state_manager,
|
||||
execution_tracker=self.state_manager,
|
||||
state_manager=self.state_manager,
|
||||
error_handler=self.error_handler,
|
||||
)
|
||||
|
||||
|
|
@ -180,8 +177,7 @@ class GraphEngine:
|
|||
# Orchestration
|
||||
self.execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self.graph_execution,
|
||||
node_state_manager=self.state_manager,
|
||||
execution_tracker=self.state_manager,
|
||||
state_manager=self.state_manager,
|
||||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
command_processor=self.command_processor,
|
||||
|
|
@ -334,7 +330,7 @@ class GraphEngine:
|
|||
# Enqueue root node
|
||||
root_node = self.graph.root_node
|
||||
self.state_manager.enqueue_node(root_node.id)
|
||||
self.state_manager.add(root_node.id)
|
||||
self.state_manager.start_execution(root_node.id)
|
||||
|
||||
# Start dispatcher
|
||||
self.dispatcher.start()
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class BranchHandler:
|
|||
graph: Graph,
|
||||
edge_processor: EdgeProcessor,
|
||||
skip_propagator: SkipPropagator,
|
||||
edge_state_manager: UnifiedStateManager,
|
||||
state_manager: UnifiedStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the branch handler.
|
||||
|
|
@ -36,12 +36,12 @@ class BranchHandler:
|
|||
graph: The workflow graph
|
||||
edge_processor: Processor for edges
|
||||
skip_propagator: Propagator for skip states
|
||||
edge_state_manager: Manager for edge states
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_processor = edge_processor
|
||||
self.skip_propagator = skip_propagator
|
||||
self.edge_state_manager = edge_state_manager
|
||||
self.state_manager = state_manager
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
|
|
@ -63,7 +63,7 @@ class BranchHandler:
|
|||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
_, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self.skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ class EdgeProcessor:
|
|||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
edge_state_manager: UnifiedStateManager,
|
||||
node_state_manager: UnifiedStateManager,
|
||||
state_manager: UnifiedStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -34,13 +33,11 @@ class EdgeProcessor:
|
|||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
edge_state_manager: Manager for edge states
|
||||
node_state_manager: Manager for node states
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_state_manager = edge_state_manager
|
||||
self.node_state_manager = node_state_manager
|
||||
self.state_manager = state_manager
|
||||
self.response_coordinator = response_coordinator
|
||||
|
||||
def process_node_success(
|
||||
|
|
@ -107,7 +104,7 @@ class EdgeProcessor:
|
|||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
|
|
@ -132,14 +129,14 @@ class EdgeProcessor:
|
|||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self.edge_state_manager.mark_edge_taken(edge.id)
|
||||
self.state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self.node_state_manager.is_node_ready(edge.head):
|
||||
if self.state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
|
@ -151,4 +148,4 @@ class EdgeProcessor:
|
|||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self.edge_state_manager.mark_edge_skipped(edge.id)
|
||||
self.state_manager.mark_edge_skipped(edge.id)
|
||||
|
|
|
|||
|
|
@ -22,20 +22,17 @@ class SkipPropagator:
|
|||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
edge_state_manager: UnifiedStateManager,
|
||||
node_state_manager: UnifiedStateManager,
|
||||
state_manager: UnifiedStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
edge_state_manager: Manager for edge states
|
||||
node_state_manager: Manager for node states
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_state_manager = edge_state_manager
|
||||
self.node_state_manager = node_state_manager
|
||||
self.state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
|
|
@ -53,7 +50,7 @@ class SkipPropagator:
|
|||
incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges)
|
||||
edge_states = self.state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
|
|
@ -62,7 +59,7 @@ class SkipPropagator:
|
|||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self.node_state_manager.enqueue_node(downstream_node_id)
|
||||
self.state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
|
|
@ -77,12 +74,12 @@ class SkipPropagator:
|
|||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self.node_state_manager.mark_node_skipped(node_id)
|
||||
self.state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self.edge_state_manager.mark_edge_skipped(edge.id)
|
||||
self.state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
|
|
@ -94,5 +91,5 @@ class SkipPropagator:
|
|||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self.edge_state_manager.mark_edge_skipped(edge.id)
|
||||
self.state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
|
|
|||
|
|
@ -26,8 +26,7 @@ class ExecutionCoordinator:
|
|||
def __init__(
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
node_state_manager: UnifiedStateManager,
|
||||
execution_tracker: UnifiedStateManager,
|
||||
state_manager: UnifiedStateManager,
|
||||
event_handler: "EventHandlerRegistry",
|
||||
event_collector: EventCollector,
|
||||
command_processor: CommandProcessor,
|
||||
|
|
@ -38,16 +37,14 @@ class ExecutionCoordinator:
|
|||
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
node_state_manager: Manager for node states
|
||||
execution_tracker: Tracker for executing nodes
|
||||
state_manager: Unified state manager
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event collector for collecting events
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self.graph_execution = graph_execution
|
||||
self.node_state_manager = node_state_manager
|
||||
self.execution_tracker = execution_tracker
|
||||
self.state_manager = state_manager
|
||||
self.event_handler = event_handler
|
||||
self.event_collector = event_collector
|
||||
self.command_processor = command_processor
|
||||
|
|
@ -59,8 +56,8 @@ class ExecutionCoordinator:
|
|||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
queue_depth = self.node_state_manager.ready_queue.qsize()
|
||||
executing_count = self.execution_tracker.count()
|
||||
queue_depth = self.state_manager.ready_queue.qsize()
|
||||
executing_count = self.state_manager.get_executing_count()
|
||||
self.worker_pool.check_scaling(queue_depth, executing_count)
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
|
|
@ -75,7 +72,7 @@ class ExecutionCoordinator:
|
|||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty()
|
||||
return self.state_manager.is_execution_complete()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
|
|
|
|||
|
|
@ -302,42 +302,3 @@ class UnifiedStateManager:
|
|||
"skipped_nodes": skipped_nodes,
|
||||
"unknown_nodes": unknown_nodes,
|
||||
}
|
||||
|
||||
# ============= Backward Compatibility Methods =============
|
||||
# These methods provide compatibility with existing code
|
||||
|
||||
@property
|
||||
def execution_tracker(self) -> "UnifiedStateManager":
|
||||
"""Compatibility property for ExecutionTracker access."""
|
||||
return self
|
||||
|
||||
@property
|
||||
def node_state_manager(self) -> "UnifiedStateManager":
|
||||
"""Compatibility property for NodeStateManager access."""
|
||||
return self
|
||||
|
||||
@property
|
||||
def edge_state_manager(self) -> "UnifiedStateManager":
|
||||
"""Compatibility property for EdgeStateManager access."""
|
||||
return self
|
||||
|
||||
# ExecutionTracker compatibility methods
|
||||
def add(self, node_id: str) -> None:
|
||||
"""Compatibility method for ExecutionTracker.add()."""
|
||||
self.start_execution(node_id)
|
||||
|
||||
def remove(self, node_id: str) -> None:
|
||||
"""Compatibility method for ExecutionTracker.remove()."""
|
||||
self.finish_execution(node_id)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Compatibility method for ExecutionTracker.is_empty()."""
|
||||
return len(self._executing_nodes) == 0
|
||||
|
||||
def count(self) -> int:
|
||||
"""Compatibility method for ExecutionTracker.count()."""
|
||||
return self.get_executing_count()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Compatibility method for ExecutionTracker.clear()."""
|
||||
self.clear_executing()
|
||||
|
|
|
|||
|
|
@ -330,31 +330,3 @@ class EnhancedWorkerPool:
|
|||
"min_workers": self.min_workers,
|
||||
"max_workers": self.max_workers,
|
||||
}
|
||||
|
||||
# ============= Backward Compatibility =============
|
||||
|
||||
def scale_up(self) -> None:
|
||||
"""Compatibility method for manual scale up."""
|
||||
with self._lock:
|
||||
if self._running and len(self.workers) < self.max_workers:
|
||||
self._add_worker()
|
||||
|
||||
def scale_down(self, worker_ids: list[int]) -> None:
|
||||
"""Compatibility method for manual scale down."""
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
for worker_id in worker_ids:
|
||||
if len(self.workers) > self.min_workers:
|
||||
self._remove_worker(worker_id)
|
||||
|
||||
def check_scaling(self, queue_depth: int, executing_count: int) -> None:
|
||||
"""
|
||||
Compatibility method for checking scaling.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth (ignored, we check directly)
|
||||
executing_count: Number of executing nodes (ignored)
|
||||
"""
|
||||
self.check_and_scale()
|
||||
|
|
|
|||
Loading…
Reference in New Issue