mirror of https://github.com/langgenius/dify.git

refactor(graph_engine): Move setup methods into `__init__`

Signed-off-by: -LAN- <laipz8200@outlook.com>

Parent: 0fdb1b2bc9
Commit: 88622f70fb
@@ -1,187 +0,0 @@
# Graph Engine

Queue-based workflow execution engine for parallel graph processing.

## Architecture

The engine uses a modular architecture with specialized packages:

### Core Components

- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
- **Event Management** (`event_management/`) - Event handling, collection, and emission
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination

### Supporting Components

- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
- **Layers** (`layers/`) - Pluggable middleware for extensions

## Architecture Diagram

```mermaid
classDiagram
    class GraphEngine {
        +run()
        +add_layer()
    }

    class Domain {
        ExecutionContext
        GraphExecution
        NodeExecution
    }

    class EventManagement {
        EventHandlerRegistry
        EventCollector
        EventEmitter
    }

    class StateManagement {
        NodeStateManager
        EdgeStateManager
        ExecutionTracker
    }

    class WorkerManagement {
        WorkerPool
        WorkerFactory
        DynamicScaler
        ActivityTracker
    }

    class GraphTraversal {
        NodeReadinessChecker
        EdgeProcessor
        BranchHandler
        SkipPropagator
    }

    class Orchestration {
        Dispatcher
        ExecutionCoordinator
    }

    class ErrorHandling {
        ErrorHandler
        RetryStrategy
        AbortStrategy
        FailBranchStrategy
    }

    class CommandProcessing {
        CommandProcessor
        AbortCommandHandler
    }

    class CommandChannels {
        InMemoryChannel
        RedisChannel
    }

    class OutputRegistry {
        <<Storage>>
        Scalar Values
        Streaming Data
    }

    class ResponseCoordinator {
        Session Management
        Path Analysis
    }

    class Layers {
        <<Plugin>>
        DebugLoggingLayer
    }

    GraphEngine --> Orchestration : coordinates
    GraphEngine --> Layers : extends

    Orchestration --> EventManagement : processes events
    Orchestration --> WorkerManagement : manages scaling
    Orchestration --> CommandProcessing : checks commands
    Orchestration --> StateManagement : monitors state

    WorkerManagement --> StateManagement : consumes ready queue
    WorkerManagement --> EventManagement : produces events
    WorkerManagement --> Domain : executes nodes

    EventManagement --> ErrorHandling : failed events
    EventManagement --> GraphTraversal : success events
    EventManagement --> ResponseCoordinator : stream events
    EventManagement --> Layers : notifies

    GraphTraversal --> StateManagement : updates states
    GraphTraversal --> Domain : checks graph

    CommandProcessing --> CommandChannels : fetches commands
    CommandProcessing --> Domain : modifies execution

    ErrorHandling --> Domain : handles failures

    StateManagement --> Domain : tracks entities

    ResponseCoordinator --> OutputRegistry : reads outputs

    Domain --> OutputRegistry : writes outputs
```

## Package Relationships

### Core Dependencies

- **Orchestration** acts as the central coordinator, managing all subsystems
- **Domain** provides the core business entities used by all packages
- **EventManagement** serves as the communication backbone between components
- **StateManagement** maintains thread-safe state for the entire system

### Data Flow

1. **Commands** flow from CommandChannels → CommandProcessing → Domain
1. **Events** flow from Workers → EventHandlerRegistry → State updates
1. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
1. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement
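
To make these hand-offs concrete, here is a minimal, self-contained sketch of the queue-based pattern: plain `queue.Queue` objects plus a worker thread. The names (`ready_queue`, `event_queue`, `outputs`, `worker`) and the tuple events are illustrative stand-ins, not the engine's actual classes.

```python
import queue
import threading

# Illustrative stand-ins for the queues and registry described above
# (not the engine's real components).
ready_queue: queue.Queue[str] = queue.Queue()                # traversal -> workers
event_queue: queue.Queue[tuple[str, str]] = queue.Queue()    # workers -> dispatcher
outputs: dict[str, str] = {}                                 # workers -> output registry

STOP = "__stop__"  # sentinel id used to shut the worker down


def worker() -> None:
    """Pull ready node ids, 'execute' them, and publish outputs and events."""
    while True:
        node_id = ready_queue.get()
        if node_id == STOP:
            break
        outputs[node_id] = f"result of {node_id}"      # node output -> registry
        event_queue.put(("node_succeeded", node_id))   # success event -> dispatcher


t = threading.Thread(target=worker, daemon=True)
t.start()
ready_queue.put("start_node")  # a traversal component would enqueue ready nodes
ready_queue.put(STOP)
t.join()
print(event_queue.get_nowait())  # ('node_succeeded', 'start_node')
```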

### Extension Points

- **Layers** observe all events for monitoring, logging, and custom logic
- **ErrorHandling** strategies can be extended for custom failure recovery
- **CommandChannels** can be implemented for different transport mechanisms
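
The `Layer` base interface is not spelled out in this README, so the following is only a hypothetical sketch of a custom monitoring layer; the `on_event` hook and the `add_layer()` attachment are assumptions inferred from the class diagram above, not the confirmed API.

```python
from collections import Counter


class EventCountingLayer:
    """Hypothetical layer that tallies how many events each node produced."""

    def __init__(self) -> None:
        self.counts: Counter[str] = Counter()

    def on_event(self, event) -> None:  # assumed hook name; the real Layer API may differ
        node_id = getattr(event, "node_id", None)
        if node_id is not None:
            self.counts[node_id] += 1


# Hypothetical usage, mirroring add_layer() from the class diagram:
# engine.add_layer(EventCountingLayer())
```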

## Execution Flow

1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
1. **Node Discovery**: Traversal components identify ready nodes
1. **Worker Execution**: Workers pull from ready queue and execute nodes
1. **Event Processing**: Dispatcher routes events to appropriate handlers
1. **State Updates**: Managers track node/edge states for next steps
1. **Completion**: Coordinator detects when all nodes are done
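
As a rough illustration of steps 4-6 (not the actual Dispatcher or ExecutionCoordinator code), a dispatch loop of this kind drains the event queue, consults a toy successor map in place of the real graph-traversal and state-management components, and stops once nothing remains in flight. The names `dispatch_loop`, `successors`, and `in_flight` are illustrative; the sketch pairs with the worker example in the Data Flow section above.

```python
import queue


def dispatch_loop(
    event_queue: queue.Queue[tuple[str, str]],
    ready_queue: queue.Queue[str],
    successors: dict[str, list[str]],
) -> None:
    """Toy dispatcher: turn success events into newly ready nodes until idle.

    `successors` maps a node id to the nodes unblocked by its completion;
    it stands in for the real graph-traversal and state-management logic.
    """
    in_flight = 1  # the root node is assumed to be enqueued already
    while in_flight:
        kind, node_id = event_queue.get()            # step 4: event processing
        if kind == "node_succeeded":
            in_flight -= 1
            for nxt in successors.get(node_id, []):  # step 5: state updates
                ready_queue.put(nxt)                 # hand off to workers
                in_flight += 1
    # step 6: completion is detected once nothing is in flight
```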

## Usage

```python
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel

# Create and run engine
engine = GraphEngine(
    tenant_id="tenant_1",
    app_id="app_1",
    workflow_id="workflow_1",
    graph=graph,
    command_channel=InMemoryChannel(),
)

# Stream execution events
for event in engine.run():
    handle_event(event)
```
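
The `handle_event` callback above is left to the caller. A minimal consumer might branch on the event types that appear in this change's tests (`NodeRunStartedEvent`, `NodeRunStreamChunkEvent`); the sketch below matches on class names rather than importing them, since their import path isn't shown here.

```python
def handle_event(event) -> None:
    """Minimal event consumer for engine.run(); adapt as needed."""
    kind = type(event).__name__
    if kind == "NodeRunStartedEvent":
        print(f"node started: {event.node_id}")
    elif kind == "NodeRunStreamChunkEvent":
        # chunk is a string fragment of streaming output (see the tests below)
        print(event.chunk, end="", flush=True)
    else:
        print(f"[{kind}]")
```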

@@ -72,9 +72,10 @@ class GraphEngine:
        scale_up_threshold: int | None = None,
        scale_down_idle_time: float | None = None,
    ) -> None:
        """Initialize the graph engine with separated concerns."""
        """Initialize the graph engine with all subsystems and dependencies."""

        # Create domain models
        # === Domain Models ===
        # Execution context encapsulates workflow execution metadata
        self._execution_context = ExecutionContext(
            tenant_id=tenant_id,
            app_id=app_id,
@@ -87,62 +88,67 @@ class GraphEngine:
            max_execution_time=max_execution_time,
        )

        # Graph execution tracks the overall execution state
        self._graph_execution = GraphExecution(workflow_id=workflow_id)

        # Store core dependencies
        # === Core Dependencies ===
        # Graph structure and configuration
        self._graph = graph
        self._graph_config = graph_config
        self._graph_runtime_state = graph_runtime_state
        self._command_channel = command_channel

        # Store worker management parameters
        # === Worker Management Parameters ===
        # Parameters for dynamic worker pool scaling
        self._min_workers = min_workers
        self._max_workers = max_workers
        self._scale_up_threshold = scale_up_threshold
        self._scale_down_idle_time = scale_down_idle_time

        # Initialize queues
        # === Execution Queues ===
        # Queue for nodes ready to execute
        self._ready_queue: queue.Queue[str] = queue.Queue()
        # Queue for events generated during execution
        self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

        # Initialize subsystems
        self._initialize_subsystems()

        # Layers for extensibility
        self._layers: list[Layer] = []

        # Validate graph state consistency
        self._validate_graph_state_consistency()

    def _initialize_subsystems(self) -> None:
        """Initialize all subsystems with proper dependency injection."""

        # Unified state management - single instance handles all state operations
        # === State Management ===
        # Unified state manager handles all node state transitions and queue operations
        self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)

        # Response coordination
        # === Response Coordination ===
        # Coordinates response streaming from response nodes
        self._response_coordinator = ResponseStreamCoordinator(
            variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
        )

        # Event management
        # === Event Management ===
        # Event collector aggregates events from all subsystems
        self._event_collector = EventCollector()
        # Event emitter streams collected events to consumers
        self._event_emitter = EventEmitter(self._event_collector)

        # Error handling
        # === Error Handling ===
        # Centralized error handler for graph execution errors
        self._error_handler = ErrorHandler(self._graph, self._graph_execution)

        # Graph traversal
        # === Graph Traversal Components ===
        # Checks if nodes are ready to execute based on their dependencies
        self._node_readiness_checker = NodeReadinessChecker(self._graph)

        # Processes edges to determine next nodes after execution
        self._edge_processor = EdgeProcessor(
            graph=self._graph,
            state_manager=self._state_manager,
            response_coordinator=self._response_coordinator,
        )

        # Propagates skip status through the graph when conditions aren't met
        self._skip_propagator = SkipPropagator(
            graph=self._graph,
            state_manager=self._state_manager,
        )

        # Handles conditional branching and route selection
        self._branch_handler = BranchHandler(
            graph=self._graph,
            edge_processor=self._edge_processor,
@@ -150,7 +156,8 @@ class GraphEngine:
            state_manager=self._state_manager,
        )

        # Event handler registry with all dependencies
        # === Event Handler Registry ===
        # Central registry for handling all node execution events
        self._event_handler_registry = EventHandlerRegistry(
            graph=self._graph,
            graph_runtime_state=self._graph_runtime_state,
@@ -163,47 +170,22 @@ class GraphEngine:
            error_handler=self._error_handler,
        )

        # Command processing
        # === Command Processing ===
        # Processes external commands (e.g., abort requests)
        self._command_processor = CommandProcessor(
            command_channel=self._command_channel,
            graph_execution=self._graph_execution,
        )
        self._setup_command_handlers()

        # Worker management
        self._setup_worker_management()

        # Orchestration
        self._execution_coordinator = ExecutionCoordinator(
            graph_execution=self._graph_execution,
            state_manager=self._state_manager,
            event_handler=self._event_handler_registry,
            event_collector=self._event_collector,
            command_processor=self._command_processor,
            worker_pool=self._worker_pool,
        )

        self._dispatcher = Dispatcher(
            event_queue=self._event_queue,
            event_handler=self._event_handler_registry,
            event_collector=self._event_collector,
            execution_coordinator=self._execution_coordinator,
            max_execution_time=self._execution_context.max_execution_time,
            event_emitter=self._event_emitter,
        )

    def _setup_command_handlers(self) -> None:
        """Configure command handlers."""
        # Create handler instance that follows the protocol
        # Register abort command handler
        abort_handler = AbortCommandHandler()
        self._command_processor.register_handler(
            AbortCommand,
            abort_handler,
        )

    def _setup_worker_management(self) -> None:
        """Initialize worker management subsystem."""
        # Capture context for workers
        # === Worker Pool Setup ===
        # Capture Flask app context for worker threads
        flask_app: Flask | None = None
        try:
            app = current_app._get_current_object()  # type: ignore
@@ -212,9 +194,10 @@ class GraphEngine:
        except RuntimeError:
            pass

        # Capture context variables for worker threads
        context_vars = contextvars.copy_context()

        # Create simple worker pool
        # Create worker pool for parallel node execution
        self._worker_pool = SimpleWorkerPool(
            ready_queue=self._ready_queue,
            event_queue=self._event_queue,
@@ -227,6 +210,35 @@ class GraphEngine:
            scale_down_idle_time=self._scale_down_idle_time,
        )

        # === Orchestration ===
        # Coordinates the overall execution lifecycle
        self._execution_coordinator = ExecutionCoordinator(
            graph_execution=self._graph_execution,
            state_manager=self._state_manager,
            event_handler=self._event_handler_registry,
            event_collector=self._event_collector,
            command_processor=self._command_processor,
            worker_pool=self._worker_pool,
        )

        # Dispatches events and manages execution flow
        self._dispatcher = Dispatcher(
            event_queue=self._event_queue,
            event_handler=self._event_handler_registry,
            event_collector=self._event_collector,
            execution_coordinator=self._execution_coordinator,
            max_execution_time=self._execution_context.max_execution_time,
            event_emitter=self._event_emitter,
        )

        # === Extensibility ===
        # Layers allow plugins to extend engine functionality
        self._layers: list[Layer] = []

        # === Validation ===
        # Ensure all nodes share the same GraphRuntimeState instance
        self._validate_graph_state_consistency()

    def _validate_graph_state_consistency(self) -> None:
        """Validate that all nodes share the same GraphRuntimeState."""
        expected_state_id = id(self._graph_runtime_state)
@@ -337,8 +349,3 @@ class GraphEngine:
    def graph_runtime_state(self) -> GraphRuntimeState:
        """Get the graph runtime state."""
        return self._graph_runtime_state

    @property
    def graph(self) -> Graph:
        """Get the graph."""
        return self._graph
@@ -100,7 +100,7 @@ def test_streaming_output_with_blocking_equals_one():
    )

    # Check that the NodeRunStreamChunkEvent containing 'query' has the same id as the Start node's NodeRunStartedEvent
    start_node_id = engine.graph.root_node.id
    start_node_id = graph.root_node.id
    start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id]
    assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}"
    start_event = start_events[0]
@@ -210,7 +210,7 @@ def test_streaming_output_with_blocking_not_equals_one():
    assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}"

    # Check that the NodeRunStreamChunkEvent containing 'query' has the same id as the Start node's NodeRunStartedEvent
    start_node_id = engine.graph.root_node.id
    start_node_id = graph.root_node.id
    start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id]
    assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}"
    start_event = start_events[0]