mirror of https://github.com/langgenius/dify.git
refactor(graph_engine): Merge branch_handler into edge_processor
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
88622f70fb
commit
bb5d52539c
|
|
@ -97,10 +97,12 @@ modules =
|
|||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.graph_traversal.node_readiness
|
||||
core.workflow.graph_engine.graph_traversal.skip_propagator
|
||||
type = layers
|
||||
layers =
|
||||
edge_processor
|
||||
skip_propagator
|
||||
containers =
|
||||
core.workflow.graph_engine.graph_traversal
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from ..response_coordinator import ResponseStreamCoordinator
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from ..error_handling import ErrorHandler
|
||||
from ..graph_traversal import BranchHandler, EdgeProcessor
|
||||
from ..graph_traversal import EdgeProcessor
|
||||
from ..state_management import UnifiedStateManager
|
||||
from .event_collector import EventCollector
|
||||
|
||||
|
|
@ -54,7 +54,6 @@ class EventHandlerRegistry:
|
|||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: "EventCollector",
|
||||
branch_handler: "BranchHandler",
|
||||
edge_processor: "EdgeProcessor",
|
||||
state_manager: "UnifiedStateManager",
|
||||
error_handler: "ErrorHandler",
|
||||
|
|
@ -68,7 +67,6 @@ class EventHandlerRegistry:
|
|||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Event collector for collecting events
|
||||
branch_handler: Branch handler for branch node processing
|
||||
edge_processor: Edge processor for edge traversal
|
||||
state_manager: Unified state manager
|
||||
error_handler: Error handler
|
||||
|
|
@ -78,7 +76,6 @@ class EventHandlerRegistry:
|
|||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._branch_handler = branch_handler
|
||||
self._edge_processor = edge_processor
|
||||
self._state_manager = state_manager
|
||||
self._error_handler = error_handler
|
||||
|
|
@ -184,7 +181,7 @@ class EventHandlerRegistry:
|
|||
# Process edges and get ready nodes
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from .domain import ExecutionContext, GraphExecution
|
|||
from .entities.commands import AbortCommand
|
||||
from .error_handling import ErrorHandler
|
||||
from .event_management import EventCollector, EventEmitter, EventHandlerRegistry
|
||||
from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
|
||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import Layer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
|
|
@ -132,28 +132,19 @@ class GraphEngine:
|
|||
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
|
||||
|
||||
# === Graph Traversal Components ===
|
||||
# Checks if nodes are ready to execute based on their dependencies
|
||||
self._node_readiness_checker = NodeReadinessChecker(self._graph)
|
||||
|
||||
# Processes edges to determine next nodes after execution
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
)
|
||||
|
||||
# Propagates skip status through the graph when conditions aren't met
|
||||
self._skip_propagator = SkipPropagator(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
|
||||
# Handles conditional branching and route selection
|
||||
self._branch_handler = BranchHandler(
|
||||
# Processes edges to determine next nodes after execution
|
||||
# Also handles conditional branching and route selection
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
edge_processor=self._edge_processor,
|
||||
skip_propagator=self._skip_propagator,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
|
|
@ -164,7 +155,6 @@ class GraphEngine:
|
|||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_collector,
|
||||
branch_handler=self._branch_handler,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
|
|
|
|||
|
|
@ -5,14 +5,10 @@ This package handles graph navigation, edge processing,
|
|||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .branch_handler import BranchHandler
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .node_readiness import NodeReadinessChecker
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"BranchHandler",
|
||||
"EdgeProcessor",
|
||||
"NodeReadinessChecker",
|
||||
"SkipPropagator",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,87 +0,0 @@
|
|||
"""
|
||||
Branch node handling for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
|
||||
|
||||
from ..state_management import UnifiedStateManager
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class BranchHandler:
|
||||
"""
|
||||
Handles branch node logic during graph traversal.
|
||||
|
||||
Branch nodes select one of multiple paths based on conditions,
|
||||
requiring special handling for edge selection and skip propagation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
edge_processor: EdgeProcessor,
|
||||
skip_propagator: SkipPropagator,
|
||||
state_manager: UnifiedStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the branch handler.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
edge_processor: Processor for edges
|
||||
skip_propagator: Propagator for skip states
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self._graph = graph
|
||||
self._edge_processor = edge_processor
|
||||
self._skip_propagator = skip_propagator
|
||||
self._state_manager = state_manager
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self._edge_processor.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
|
|
@ -3,7 +3,7 @@ Edge processing logic for graph traversal.
|
|||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
|
@ -12,6 +12,9 @@ from core.workflow.graph_events import NodeRunStreamChunkEvent
|
|||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
from ..state_management import UnifiedStateManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
|
|
@ -19,7 +22,8 @@ class EdgeProcessor:
|
|||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, and triggering downstream node execution.
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -27,6 +31,7 @@ class EdgeProcessor:
|
|||
graph: Graph,
|
||||
state_manager: UnifiedStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
|
@ -35,10 +40,12 @@ class EdgeProcessor:
|
|||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
|
|
@ -149,3 +156,46 @@ class EdgeProcessor:
|
|||
edge: The edge to skip
|
||||
"""
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
|
|
|
|||
|
|
@ -1,86 +0,0 @@
|
|||
"""
|
||||
Node readiness checking for execution.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeReadinessChecker:
|
||||
"""
|
||||
Checks if nodes are ready for execution based on their dependencies.
|
||||
|
||||
A node is ready when its dependencies (incoming edges) have been
|
||||
satisfied according to the graph's execution rules.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph) -> None:
|
||||
"""
|
||||
Initialize the readiness checker.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
"""
|
||||
self._graph = graph
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when:
|
||||
- It has no incoming edges (root or isolated node), OR
|
||||
- At least one incoming edge is TAKEN and none are UNKNOWN
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# No dependencies means always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# Check edge states
|
||||
has_unknown = False
|
||||
has_taken = False
|
||||
|
||||
for edge in incoming_edges:
|
||||
if edge.state == NodeState.UNKNOWN:
|
||||
has_unknown = True
|
||||
break
|
||||
elif edge.state == NodeState.TAKEN:
|
||||
has_taken = True
|
||||
|
||||
# Not ready if any dependency is still unknown
|
||||
if has_unknown:
|
||||
return False
|
||||
|
||||
# Ready if at least one path is taken
|
||||
return has_taken
|
||||
|
||||
def get_ready_downstream_nodes(self, from_node_id: str) -> list[str]:
|
||||
"""
|
||||
Get all downstream nodes that are ready after a node completes.
|
||||
|
||||
Args:
|
||||
from_node_id: The ID of the completed node
|
||||
|
||||
Returns:
|
||||
List of node IDs that are now ready
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
outgoing_edges = self._graph.get_outgoing_edges(from_node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.state == NodeState.TAKEN:
|
||||
downstream_node_id = edge.head
|
||||
if self.is_node_ready(downstream_node_id):
|
||||
ready_nodes.append(downstream_node_id)
|
||||
|
||||
return ready_nodes
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
# Worker Management
|
||||
|
||||
Dynamic worker pool for node execution.
|
||||
|
||||
## Components
|
||||
|
||||
### WorkerPool
|
||||
|
||||
Manages worker thread lifecycle.
|
||||
|
||||
- `start/stop/wait()` - Control workers
|
||||
- `scale_up/down()` - Adjust pool size
|
||||
- `get_worker_count()` - Current count
|
||||
|
||||
### WorkerFactory
|
||||
|
||||
Creates workers with Flask context.
|
||||
|
||||
- `create_worker()` - Build with dependencies
|
||||
- Preserves request context
|
||||
|
||||
### DynamicScaler
|
||||
|
||||
Determines scaling decisions.
|
||||
|
||||
- `min/max_workers` - Pool bounds
|
||||
- `scale_up_threshold` - Queue trigger
|
||||
- `should_scale_up/down()` - Check conditions
|
||||
|
||||
### ActivityTracker
|
||||
|
||||
Tracks worker activity.
|
||||
|
||||
- `track_activity(worker_id)` - Record activity
|
||||
- `get_idle_workers(threshold)` - Find idle
|
||||
- `get_active_count()` - Active count
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
scaler = DynamicScaler(
|
||||
min_workers=2,
|
||||
max_workers=10,
|
||||
scale_up_threshold=5
|
||||
)
|
||||
|
||||
pool = WorkerPool(
|
||||
ready_queue=ready_queue,
|
||||
worker_factory=factory,
|
||||
dynamic_scaler=scaler
|
||||
)
|
||||
|
||||
pool.start()
|
||||
|
||||
# Scale based on load
|
||||
if scaler.should_scale_up(queue_size, active):
|
||||
pool.scale_up()
|
||||
|
||||
pool.stop()
|
||||
```
|
||||
|
||||
## Scaling Strategy
|
||||
|
||||
**Scale Up**: Queue size > threshold AND workers < max
|
||||
**Scale Down**: Idle workers exist AND workers > min
|
||||
|
||||
## Parameters
|
||||
|
||||
- `min_workers` - Minimum pool size
|
||||
- `max_workers` - Maximum pool size
|
||||
- `scale_up_threshold` - Queue trigger
|
||||
- `scale_down_threshold` - Idle seconds
|
||||
|
||||
## Flask Context
|
||||
|
||||
WorkerFactory preserves request context across threads:
|
||||
|
||||
```python
|
||||
context_vars = {"request_id": request.id}
|
||||
# Workers receive same context
|
||||
```
|
||||
Loading…
Reference in New Issue