mirror of https://github.com/langgenius/dify.git
refactor(graph_engine): Correct private attributes and private methods naming
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
a5cb9d2b73
commit
0fdb1b2bc9
|
|
@ -39,8 +39,8 @@ class CommandProcessor:
|
|||
command_channel: Channel for receiving commands
|
||||
graph_execution: Graph execution aggregate
|
||||
"""
|
||||
self.command_channel = command_channel
|
||||
self.graph_execution = graph_execution
|
||||
self._command_channel = command_channel
|
||||
self._graph_execution = graph_execution
|
||||
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
|
||||
|
||||
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
|
||||
|
|
@ -56,7 +56,7 @@ class CommandProcessor:
|
|||
def process_commands(self) -> None:
|
||||
"""Check for and process any pending commands."""
|
||||
try:
|
||||
commands = self.command_channel.fetch_commands()
|
||||
commands = self._command_channel.fetch_commands()
|
||||
for command in commands:
|
||||
self._handle_command(command)
|
||||
except Exception as e:
|
||||
|
|
@ -72,7 +72,7 @@ class CommandProcessor:
|
|||
handler = self._handlers.get(type(command))
|
||||
if handler:
|
||||
try:
|
||||
handler.handle(command, self.graph_execution)
|
||||
handler.handle(command, self._graph_execution)
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ class AbortStrategy:
|
|||
Returns:
|
||||
None - signals abortion
|
||||
"""
|
||||
_ = graph
|
||||
_ = retry_count
|
||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||
|
||||
# Return None to signal that execution should stop
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class DefaultValueStrategy:
|
|||
Returns:
|
||||
NodeRunExceptionEvent with default values
|
||||
"""
|
||||
_ = retry_count
|
||||
node = graph.nodes[event.node_id]
|
||||
|
||||
outputs = {
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ class FailBranchStrategy:
|
|||
Returns:
|
||||
NodeRunExceptionEvent to continue via fail branch
|
||||
"""
|
||||
_ = graph
|
||||
_ = retry_count
|
||||
outputs = {
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class ReadWriteLock:
|
|||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
self._read_ready.acquire()
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
|
|
@ -31,7 +31,7 @@ class ReadWriteLock:
|
|||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
self._read_ready.acquire()
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
|
|
@ -41,9 +41,9 @@ class ReadWriteLock:
|
|||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
self._read_ready.acquire()
|
||||
_ = self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
self._read_ready.wait()
|
||||
_ = self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class EventEmitter:
|
|||
Args:
|
||||
event_collector: The collector to emit events from
|
||||
"""
|
||||
self.event_collector = event_collector
|
||||
self._event_collector = event_collector
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
|
|
@ -44,9 +44,9 @@ class EventEmitter:
|
|||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
|
||||
while not self._execution_complete.is_set() or yielded_count < self._event_collector.event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self.event_collector.get_new_events(yielded_count)
|
||||
new_events = self._event_collector.get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class GraphEngine:
|
|||
"""Initialize the graph engine with separated concerns."""
|
||||
|
||||
# Create domain models
|
||||
self.execution_context = ExecutionContext(
|
||||
self._execution_context = ExecutionContext(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
|
|
@ -87,13 +87,13 @@ class GraphEngine:
|
|||
max_execution_time=max_execution_time,
|
||||
)
|
||||
|
||||
self.graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
|
||||
# Store core dependencies
|
||||
self.graph = graph
|
||||
self.graph_config = graph_config
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
self._graph = graph
|
||||
self._graph_config = graph_config
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._command_channel = command_channel
|
||||
|
||||
# Store worker management parameters
|
||||
self._min_workers = min_workers
|
||||
|
|
@ -102,8 +102,8 @@ class GraphEngine:
|
|||
self._scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
# Initialize queues
|
||||
self.ready_queue: queue.Queue[str] = queue.Queue()
|
||||
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
self._ready_queue: queue.Queue[str] = queue.Queue()
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# Initialize subsystems
|
||||
self._initialize_subsystems()
|
||||
|
|
@ -118,55 +118,55 @@ class GraphEngine:
|
|||
"""Initialize all subsystems with proper dependency injection."""
|
||||
|
||||
# Unified state management - single instance handles all state operations
|
||||
self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
|
||||
self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
|
||||
|
||||
# Response coordination
|
||||
self.response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
|
||||
self._response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||
)
|
||||
|
||||
# Event management
|
||||
self.event_collector = EventCollector()
|
||||
self.event_emitter = EventEmitter(self.event_collector)
|
||||
self._event_collector = EventCollector()
|
||||
self._event_emitter = EventEmitter(self._event_collector)
|
||||
|
||||
# Error handling
|
||||
self.error_handler = ErrorHandler(self.graph, self.graph_execution)
|
||||
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
|
||||
|
||||
# Graph traversal
|
||||
self.node_readiness_checker = NodeReadinessChecker(self.graph)
|
||||
self.edge_processor = EdgeProcessor(
|
||||
graph=self.graph,
|
||||
state_manager=self.state_manager,
|
||||
response_coordinator=self.response_coordinator,
|
||||
self._node_readiness_checker = NodeReadinessChecker(self._graph)
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
)
|
||||
self.skip_propagator = SkipPropagator(
|
||||
graph=self.graph,
|
||||
state_manager=self.state_manager,
|
||||
self._skip_propagator = SkipPropagator(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
self.branch_handler = BranchHandler(
|
||||
graph=self.graph,
|
||||
edge_processor=self.edge_processor,
|
||||
skip_propagator=self.skip_propagator,
|
||||
state_manager=self.state_manager,
|
||||
self._branch_handler = BranchHandler(
|
||||
graph=self._graph,
|
||||
edge_processor=self._edge_processor,
|
||||
skip_propagator=self._skip_propagator,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
|
||||
# Event handler registry with all dependencies
|
||||
self.event_handler_registry = EventHandlerRegistry(
|
||||
graph=self.graph,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
graph_execution=self.graph_execution,
|
||||
response_coordinator=self.response_coordinator,
|
||||
event_collector=self.event_collector,
|
||||
branch_handler=self.branch_handler,
|
||||
edge_processor=self.edge_processor,
|
||||
state_manager=self.state_manager,
|
||||
error_handler=self.error_handler,
|
||||
self._event_handler_registry = EventHandlerRegistry(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_collector,
|
||||
branch_handler=self._branch_handler,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# Command processing
|
||||
self.command_processor = CommandProcessor(
|
||||
command_channel=self.command_channel,
|
||||
graph_execution=self.graph_execution,
|
||||
self._command_processor = CommandProcessor(
|
||||
command_channel=self._command_channel,
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
self._setup_command_handlers()
|
||||
|
||||
|
|
@ -174,29 +174,29 @@ class GraphEngine:
|
|||
self._setup_worker_management()
|
||||
|
||||
# Orchestration
|
||||
self.execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self.graph_execution,
|
||||
state_manager=self.state_manager,
|
||||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
command_processor=self.command_processor,
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_collector,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
self.dispatcher = Dispatcher(
|
||||
event_queue=self.event_queue,
|
||||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
execution_coordinator=self.execution_coordinator,
|
||||
max_execution_time=self.execution_context.max_execution_time,
|
||||
event_emitter=self.event_emitter,
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_collector,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
max_execution_time=self._execution_context.max_execution_time,
|
||||
event_emitter=self._event_emitter,
|
||||
)
|
||||
|
||||
def _setup_command_handlers(self) -> None:
|
||||
"""Configure command handlers."""
|
||||
# Create handler instance that follows the protocol
|
||||
abort_handler = AbortCommandHandler()
|
||||
self.command_processor.register_handler(
|
||||
self._command_processor.register_handler(
|
||||
AbortCommand,
|
||||
abort_handler,
|
||||
)
|
||||
|
|
@ -216,9 +216,9 @@ class GraphEngine:
|
|||
|
||||
# Create simple worker pool
|
||||
self._worker_pool = SimpleWorkerPool(
|
||||
ready_queue=self.ready_queue,
|
||||
event_queue=self.event_queue,
|
||||
graph=self.graph,
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
min_workers=self._min_workers,
|
||||
|
|
@ -229,8 +229,8 @@ class GraphEngine:
|
|||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
"""Validate that all nodes share the same GraphRuntimeState."""
|
||||
expected_state_id = id(self.graph_runtime_state)
|
||||
for node in self.graph.nodes.values():
|
||||
expected_state_id = id(self._graph_runtime_state)
|
||||
for node in self._graph.nodes.values():
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
|
|
@ -251,7 +251,7 @@ class GraphEngine:
|
|||
self._initialize_layers()
|
||||
|
||||
# Start execution
|
||||
self.graph_execution.start()
|
||||
self._graph_execution.start()
|
||||
start_event = GraphRunStartedEvent()
|
||||
yield start_event
|
||||
|
||||
|
|
@ -259,23 +259,23 @@ class GraphEngine:
|
|||
self._start_execution()
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self.event_emitter.emit_events()
|
||||
yield from self._event_emitter.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self.graph_execution.aborted:
|
||||
if self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self.graph_execution.error:
|
||||
abort_reason = str(self.graph_execution.error)
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
yield GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self.graph_runtime_state.outputs,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
elif self.graph_execution.has_error:
|
||||
if self.graph_execution.error:
|
||||
raise self.graph_execution.error
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=self.graph_runtime_state.outputs,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -287,10 +287,10 @@ class GraphEngine:
|
|||
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self.event_collector.set_layers(self._layers)
|
||||
self._event_collector.set_layers(self._layers)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(self.graph_runtime_state, self.command_channel)
|
||||
layer.initialize(self._graph_runtime_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
|
|
@ -305,21 +305,21 @@ class GraphEngine:
|
|||
self._worker_pool.start()
|
||||
|
||||
# Register response nodes
|
||||
for node in self.graph.nodes.values():
|
||||
for node in self._graph.nodes.values():
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self.response_coordinator.register(node.id)
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
# Enqueue root node
|
||||
root_node = self.graph.root_node
|
||||
self.state_manager.enqueue_node(root_node.id)
|
||||
self.state_manager.start_execution(root_node.id)
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
|
||||
# Start dispatcher
|
||||
self.dispatcher.start()
|
||||
self._dispatcher.start()
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self.dispatcher.stop()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
|
|
@ -328,6 +328,17 @@ class GraphEngine:
|
|||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self.graph_execution.error)
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
def graph_runtime_state(self) -> GraphRuntimeState:
|
||||
"""Get the graph runtime state."""
|
||||
return self._graph_runtime_state
|
||||
|
||||
@property
|
||||
def graph(self) -> Graph:
|
||||
"""Get the graph."""
|
||||
return self._graph
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ class BranchHandler:
|
|||
skip_propagator: Propagator for skip states
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_processor = edge_processor
|
||||
self.skip_propagator = skip_propagator
|
||||
self.state_manager = state_manager
|
||||
self._graph = graph
|
||||
self._edge_processor = edge_processor
|
||||
self._skip_propagator = skip_propagator
|
||||
self._state_manager = state_manager
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
|
|
@ -63,13 +63,13 @@ class BranchHandler:
|
|||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self.skip_propagator.skip_branch_paths(unselected_edges)
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.edge_processor.process_node_success(node_id, selected_handle)
|
||||
return self._edge_processor.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
|
|
@ -82,6 +82,6 @@ class BranchHandler:
|
|||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
|
|
|
|||
|
|
@ -36,9 +36,9 @@ class EdgeProcessor:
|
|||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
"""
|
||||
self.graph = graph
|
||||
self.state_manager = state_manager
|
||||
self.response_coordinator = response_coordinator
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
|
|
@ -53,7 +53,7 @@ class EdgeProcessor:
|
|||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self.graph.nodes[node_id]
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
|
|
@ -72,7 +72,7 @@ class EdgeProcessor:
|
|||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
|
|
@ -104,7 +104,7 @@ class EdgeProcessor:
|
|||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
|
|
@ -129,14 +129,14 @@ class EdgeProcessor:
|
|||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self.state_manager.mark_edge_taken(edge.id)
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self.state_manager.is_node_ready(edge.head):
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
|
@ -148,4 +148,4 @@ class EdgeProcessor:
|
|||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self.state_manager.mark_edge_skipped(edge.id)
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class NodeReadinessChecker:
|
|||
Args:
|
||||
graph: The workflow graph
|
||||
"""
|
||||
self.graph = graph
|
||||
self._graph = graph
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
|
|
@ -40,7 +40,7 @@ class NodeReadinessChecker:
|
|||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
incoming_edges = self.graph.get_incoming_edges(node_id)
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# No dependencies means always ready
|
||||
if not incoming_edges:
|
||||
|
|
@ -75,7 +75,7 @@ class NodeReadinessChecker:
|
|||
List of node IDs that are now ready
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(from_node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.state == NodeState.TAKEN:
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ class SkipPropagator:
|
|||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self.graph = graph
|
||||
self.state_manager = state_manager
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
|
|
@ -46,11 +46,11 @@ class SkipPropagator:
|
|||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self.graph.edges[edge_id].head
|
||||
incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self.state_manager.analyze_edge_states(incoming_edges)
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
|
|
@ -59,7 +59,7 @@ class SkipPropagator:
|
|||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self.state_manager.enqueue_node(downstream_node_id)
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
|
|
@ -74,12 +74,12 @@ class SkipPropagator:
|
|||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self.state_manager.mark_node_skipped(node_id)
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self.state_manager.mark_edge_skipped(edge.id)
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
|
|
@ -91,5 +91,5 @@ class SkipPropagator:
|
|||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self.state_manager.mark_edge_skipped(edge.id)
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
|
|
|||
|
|
@ -48,12 +48,12 @@ class Dispatcher:
|
|||
max_execution_time: Maximum execution time in seconds
|
||||
event_emitter: Optional event emitter to signal completion
|
||||
"""
|
||||
self.event_queue = event_queue
|
||||
self.event_handler = event_handler
|
||||
self.event_collector = event_collector
|
||||
self.execution_coordinator = execution_coordinator
|
||||
self.max_execution_time = max_execution_time
|
||||
self.event_emitter = event_emitter
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._max_execution_time = max_execution_time
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
|
|
@ -80,28 +80,28 @@ class Dispatcher:
|
|||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for commands
|
||||
self.execution_coordinator.check_commands()
|
||||
self._execution_coordinator.check_commands()
|
||||
|
||||
# Check for scaling
|
||||
self.execution_coordinator.check_scaling()
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self.event_queue.get(timeout=0.1)
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self.event_handler.handle_event(event)
|
||||
self.event_queue.task_done()
|
||||
self._event_handler.handle_event(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Check if execution is complete
|
||||
if self.execution_coordinator.is_execution_complete():
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
self.execution_coordinator.mark_failed(e)
|
||||
self._execution_coordinator.mark_failed(e)
|
||||
|
||||
finally:
|
||||
self.execution_coordinator.mark_complete()
|
||||
self._execution_coordinator.mark_complete()
|
||||
# Signal the event emitter that execution is complete
|
||||
if self.event_emitter:
|
||||
self.event_emitter.mark_complete()
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
|
|
|||
|
|
@ -43,20 +43,20 @@ class ExecutionCoordinator:
|
|||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self.graph_execution = graph_execution
|
||||
self.state_manager = state_manager
|
||||
self.event_handler = event_handler
|
||||
self.event_collector = event_collector
|
||||
self.command_processor = command_processor
|
||||
self.worker_pool = worker_pool
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def check_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self.command_processor.process_commands()
|
||||
self._command_processor.process_commands()
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
self.worker_pool.check_and_scale()
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
|
|
@ -66,16 +66,16 @@ class ExecutionCoordinator:
|
|||
True if execution is complete
|
||||
"""
|
||||
# Check if aborted or failed
|
||||
if self.graph_execution.aborted or self.graph_execution.has_error:
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self.state_manager.is_execution_complete()
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if not self.graph_execution.completed:
|
||||
self.graph_execution.complete()
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
"""
|
||||
|
|
@ -84,4 +84,4 @@ class ExecutionCoordinator:
|
|||
Args:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self.graph_execution.fail(error)
|
||||
self._graph_execution.fail(error)
|
||||
|
|
|
|||
|
|
@ -44,11 +44,11 @@ class ResponseStreamCoordinator:
|
|||
variable_pool: VariablePool instance for accessing node variables
|
||||
graph: Graph instance for looking up node information
|
||||
"""
|
||||
self.variable_pool = variable_pool
|
||||
self.graph = graph
|
||||
self.active_session: ResponseSession | None = None
|
||||
self.waiting_sessions: deque[ResponseSession] = deque()
|
||||
self.lock = RLock()
|
||||
self._variable_pool = variable_pool
|
||||
self._graph = graph
|
||||
self._active_session: ResponseSession | None = None
|
||||
self._waiting_sessions: deque[ResponseSession] = deque()
|
||||
self._lock = RLock()
|
||||
|
||||
# Internal stream management (replacing OutputRegistry)
|
||||
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
|
||||
|
|
@ -68,7 +68,7 @@ class ResponseStreamCoordinator:
|
|||
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
|
||||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
|
|
@ -76,7 +76,7 @@ class ResponseStreamCoordinator:
|
|||
self._paths_maps[response_node_id] = paths_map
|
||||
|
||||
# Create and store response session for this node
|
||||
response_node = self.graph.nodes[response_node_id]
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
session = ResponseSession.from_node(response_node)
|
||||
self._response_sessions[response_node_id] = session
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ class ResponseStreamCoordinator:
|
|||
node_id: The ID of the node
|
||||
execution_id: The execution ID from NodeRunStartedEvent
|
||||
"""
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
self._node_execution_ids[node_id] = execution_id
|
||||
|
||||
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
|
||||
|
|
@ -99,7 +99,7 @@ class ResponseStreamCoordinator:
|
|||
Returns:
|
||||
The execution ID for the node
|
||||
"""
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
if node_id not in self._node_execution_ids:
|
||||
self._node_execution_ids[node_id] = str(uuid4())
|
||||
return self._node_execution_ids[node_id]
|
||||
|
|
@ -116,14 +116,14 @@ class ResponseStreamCoordinator:
|
|||
List of Path objects, where each path contains branch edge IDs
|
||||
"""
|
||||
# Get root node ID
|
||||
root_node_id = self.graph.root_node.id
|
||||
root_node_id = self._graph.root_node.id
|
||||
|
||||
# If root is the response node, return empty path
|
||||
if root_node_id == response_node_id:
|
||||
return [Path()]
|
||||
|
||||
# Extract variable selectors from the response node's template
|
||||
response_node = self.graph.nodes[response_node_id]
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
response_session = ResponseSession.from_node(response_node)
|
||||
template = response_session.template
|
||||
|
||||
|
|
@ -149,7 +149,7 @@ class ResponseStreamCoordinator:
|
|||
visited.add(current_node_id)
|
||||
|
||||
# Explore outgoing edges
|
||||
outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
|
||||
for edge in outgoing_edges:
|
||||
edge_id = edge.id
|
||||
next_node_id = edge.head
|
||||
|
|
@ -168,8 +168,8 @@ class ResponseStreamCoordinator:
|
|||
for path in all_complete_paths:
|
||||
blocking_edges: list[str] = []
|
||||
for edge_id in path:
|
||||
edge = self.graph.edges[edge_id]
|
||||
source_node = self.graph.nodes[edge.tail]
|
||||
edge = self._graph.edges[edge_id]
|
||||
source_node = self._graph.nodes[edge.tail]
|
||||
|
||||
# Check if node is a branch/container (original behavior)
|
||||
if source_node.execution_type in {
|
||||
|
|
@ -199,7 +199,7 @@ class ResponseStreamCoordinator:
|
|||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
# Check each response node in order
|
||||
for response_node_id in self._response_nodes:
|
||||
if response_node_id not in self._paths_maps:
|
||||
|
|
@ -245,21 +245,21 @@ class ResponseStreamCoordinator:
|
|||
# Remove from map to ensure it won't be activated again
|
||||
del self._response_sessions[node_id]
|
||||
|
||||
if self.active_session is None:
|
||||
self.active_session = session
|
||||
if self._active_session is None:
|
||||
self._active_session = session
|
||||
|
||||
# Try to flush immediately
|
||||
events.extend(self.try_flush())
|
||||
else:
|
||||
# Queue the session if another is active
|
||||
self.waiting_sessions.append(session)
|
||||
self._waiting_sessions.append(session)
|
||||
|
||||
return events
|
||||
|
||||
def intercept_event(
|
||||
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
|
||||
) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._append_stream_chunk(event.selector, event)
|
||||
if event.is_final:
|
||||
|
|
@ -269,9 +269,8 @@ class ResponseStreamCoordinator:
|
|||
# Skip cause we share the same variable pool.
|
||||
#
|
||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
# self.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
# self._variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
return self.try_flush()
|
||||
return []
|
||||
|
||||
def _create_stream_chunk_event(
|
||||
self,
|
||||
|
|
@ -287,9 +286,9 @@ class ResponseStreamCoordinator:
|
|||
active response node's information since these are not actual node IDs.
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self.graph.nodes and self.active_session:
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
# Use the active response node for special selectors
|
||||
response_node = self.graph.nodes[self.active_session.node_id]
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
|
|
@ -300,7 +299,7 @@ class ResponseStreamCoordinator:
|
|||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
node = self.graph.nodes[node_id]
|
||||
node = self._graph.nodes[node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
|
|
@ -323,9 +322,9 @@ class ResponseStreamCoordinator:
|
|||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self.active_session and source_selector_prefix not in self.graph.nodes:
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self.active_session.node_id
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
|
|
@ -336,8 +335,8 @@ class ResponseStreamCoordinator:
|
|||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self.active_session and source_selector_prefix not in self.graph.nodes:
|
||||
response_node = self.graph.nodes[self.active_session.node_id]
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
|
|
@ -359,10 +358,10 @@ class ResponseStreamCoordinator:
|
|||
if stream_closed:
|
||||
is_complete = True
|
||||
|
||||
elif value := self.variable_pool.get(segment.selector):
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
|
|
@ -379,13 +378,13 @@ class ResponseStreamCoordinator:
|
|||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self.active_session is not None
|
||||
current_response_node = self.graph.nodes[self.active_session.node_id]
|
||||
assert self._active_session is not None
|
||||
current_response_node = self._graph.nodes[self._active_session.node_id]
|
||||
|
||||
# Use get_or_create_execution_id to ensure we have a consistent ID
|
||||
execution_id = self._get_or_create_execution_id(current_response_node.id)
|
||||
|
||||
is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1
|
||||
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
event = self._create_stream_chunk_event(
|
||||
node_id=current_response_node.id,
|
||||
execution_id=execution_id,
|
||||
|
|
@ -396,29 +395,29 @@ class ResponseStreamCoordinator:
|
|||
return [event]
|
||||
|
||||
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
|
||||
with self.lock:
|
||||
if not self.active_session:
|
||||
with self._lock:
|
||||
if not self._active_session:
|
||||
return []
|
||||
|
||||
template = self.active_session.template
|
||||
response_node_id = self.active_session.node_id
|
||||
template = self._active_session.template
|
||||
response_node_id = self._active_session.node_id
|
||||
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Process segments sequentially from current index
|
||||
while self.active_session.index < len(template.segments):
|
||||
segment = template.segments[self.active_session.index]
|
||||
while self._active_session.index < len(template.segments):
|
||||
segment = template.segments[self._active_session.index]
|
||||
|
||||
if isinstance(segment, VariableSegment):
|
||||
# Check if the source node for this variable is skipped
|
||||
# Only check for actual nodes, not special selectors (sys, env, conversation)
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
if source_selector_prefix in self.graph.nodes:
|
||||
source_node = self.graph.nodes[source_selector_prefix]
|
||||
if source_selector_prefix in self._graph.nodes:
|
||||
source_node = self._graph.nodes[source_selector_prefix]
|
||||
|
||||
if source_node.state == NodeState.SKIPPED:
|
||||
# Skip this variable segment if the source node is skipped
|
||||
self.active_session.index += 1
|
||||
self._active_session.index += 1
|
||||
continue
|
||||
|
||||
segment_events, is_complete = self._process_variable_segment(segment)
|
||||
|
|
@ -426,7 +425,7 @@ class ResponseStreamCoordinator:
|
|||
|
||||
# Only advance index if this variable segment is complete
|
||||
if is_complete:
|
||||
self.active_session.index += 1
|
||||
self._active_session.index += 1
|
||||
else:
|
||||
# Wait for more data
|
||||
break
|
||||
|
|
@ -434,9 +433,9 @@ class ResponseStreamCoordinator:
|
|||
else:
|
||||
segment_events = self._process_text_segment(segment)
|
||||
events.extend(segment_events)
|
||||
self.active_session.index += 1
|
||||
self._active_session.index += 1
|
||||
|
||||
if self.active_session.is_complete():
|
||||
if self._active_session.is_complete():
|
||||
# End current session and get events from starting next session
|
||||
next_session_events = self.end_session(response_node_id)
|
||||
events.extend(next_session_events)
|
||||
|
|
@ -454,16 +453,16 @@ class ResponseStreamCoordinator:
|
|||
Returns:
|
||||
List of events from starting the next session
|
||||
"""
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
if self.active_session and self.active_session.node_id == node_id:
|
||||
self.active_session = None
|
||||
if self._active_session and self._active_session.node_id == node_id:
|
||||
self._active_session = None
|
||||
|
||||
# Try to start next waiting session
|
||||
if self.waiting_sessions:
|
||||
next_session = self.waiting_sessions.popleft()
|
||||
self.active_session = next_session
|
||||
if self._waiting_sessions:
|
||||
next_session = self._waiting_sessions.popleft()
|
||||
self._active_session = next_session
|
||||
|
||||
# Immediately try to flush any available segments
|
||||
events = self.try_flush()
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ class UnifiedStateManager:
|
|||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self.graph = graph
|
||||
self.ready_queue = ready_queue
|
||||
self._graph = graph
|
||||
self._ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Execution tracking state
|
||||
|
|
@ -66,8 +66,8 @@ class UnifiedStateManager:
|
|||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self.ready_queue.put(node_id)
|
||||
self._graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self._ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
|
|
@ -77,7 +77,7 @@ class UnifiedStateManager:
|
|||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
self._graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
|
|
@ -94,7 +94,7 @@ class UnifiedStateManager:
|
|||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self.graph.get_incoming_edges(node_id)
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
|
|
@ -118,7 +118,7 @@ class UnifiedStateManager:
|
|||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self.graph.nodes[node_id].state
|
||||
return self._graph.nodes[node_id].state
|
||||
|
||||
# ============= Edge State Operations =============
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ class UnifiedStateManager:
|
|||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.edges[edge_id].state = NodeState.TAKEN
|
||||
self._graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
|
|
@ -140,7 +140,7 @@ class UnifiedStateManager:
|
|||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
self._graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
|
|
@ -172,7 +172,7 @@ class UnifiedStateManager:
|
|||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self.graph.edges[edge_id].state
|
||||
return self._graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
|
|
@ -186,7 +186,7 @@ class UnifiedStateManager:
|
|||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
|
|
@ -272,7 +272,7 @@ class UnifiedStateManager:
|
|||
True if execution is complete
|
||||
"""
|
||||
with self._lock:
|
||||
return self.ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
return self._ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
|
||||
def get_queue_depth(self) -> int:
|
||||
"""
|
||||
|
|
@ -281,7 +281,7 @@ class UnifiedStateManager:
|
|||
Returns:
|
||||
Number of nodes in the ready queue
|
||||
"""
|
||||
return self.ready_queue.qsize()
|
||||
return self._ready_queue.qsize()
|
||||
|
||||
def get_execution_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
|
|
@ -291,12 +291,12 @@ class UnifiedStateManager:
|
|||
Dictionary with execution statistics
|
||||
"""
|
||||
with self._lock:
|
||||
taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
|
||||
return {
|
||||
"queue_depth": self.ready_queue.qsize(),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"executing": len(self._executing_nodes),
|
||||
"taken_nodes": taken_nodes,
|
||||
"skipped_nodes": skipped_nodes,
|
||||
|
|
|
|||
|
|
@ -59,16 +59,16 @@ class Worker(threading.Thread):
|
|||
on_active_callback: Optional callback when worker becomes active
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self.ready_queue = ready_queue
|
||||
self.event_queue = event_queue
|
||||
self.graph = graph
|
||||
self.worker_id = worker_id
|
||||
self.flask_app = flask_app
|
||||
self.context_vars = context_vars
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self.on_idle_callback = on_idle_callback
|
||||
self.on_active_callback = on_active_callback
|
||||
self.last_task_time = time.time()
|
||||
self._on_idle_callback = on_idle_callback
|
||||
self._on_active_callback = on_active_callback
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
|
|
@ -85,22 +85,22 @@ class Worker(threading.Thread):
|
|||
while not self._stop_event.is_set():
|
||||
# Try to get a node ID from the ready queue (with timeout)
|
||||
try:
|
||||
node_id = self.ready_queue.get(timeout=0.1)
|
||||
node_id = self._ready_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
# Notify that worker is idle
|
||||
if self.on_idle_callback:
|
||||
self.on_idle_callback(self.worker_id)
|
||||
if self._on_idle_callback:
|
||||
self._on_idle_callback(self._worker_id)
|
||||
continue
|
||||
|
||||
# Notify that worker is active
|
||||
if self.on_active_callback:
|
||||
self.on_active_callback(self.worker_id)
|
||||
if self._on_active_callback:
|
||||
self._on_active_callback(self._worker_id)
|
||||
|
||||
self.last_task_time = time.time()
|
||||
node = self.graph.nodes[node_id]
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._execute_node(node)
|
||||
self.ready_queue.task_done()
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
|
|
@ -110,7 +110,7 @@ class Worker(threading.Thread):
|
|||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
)
|
||||
self.event_queue.put(error_event)
|
||||
self._event_queue.put(error_event)
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
|
|
@ -120,19 +120,19 @@ class Worker(threading.Thread):
|
|||
node: The node instance to execute
|
||||
"""
|
||||
# Execute the node with preserved context if Flask app is provided
|
||||
if self.flask_app and self.context_vars:
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self.flask_app,
|
||||
context_vars=self.context_vars,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
# Execute the node
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self.event_queue.put(event)
|
||||
self._event_queue.put(event)
|
||||
else:
|
||||
# Execute without context preservation
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self.event_queue.put(event)
|
||||
self._event_queue.put(event)
|
||||
|
|
|
|||
|
|
@ -56,20 +56,20 @@ class SimpleWorkerPool:
|
|||
scale_up_threshold: Queue depth to trigger scale up
|
||||
scale_down_idle_time: Seconds before scaling down idle workers
|
||||
"""
|
||||
self.ready_queue = ready_queue
|
||||
self.event_queue = event_queue
|
||||
self.graph = graph
|
||||
self.flask_app = flask_app
|
||||
self.context_vars = context_vars
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
|
||||
# Scaling parameters with defaults
|
||||
self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
||||
self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
||||
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
|
||||
# Worker management
|
||||
self.workers: list[Worker] = []
|
||||
self._workers: list[Worker] = []
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
|
@ -89,13 +89,13 @@ class SimpleWorkerPool:
|
|||
|
||||
# Calculate initial worker count
|
||||
if initial_count is None:
|
||||
node_count = len(self.graph.nodes)
|
||||
node_count = len(self._graph.nodes)
|
||||
if node_count < 10:
|
||||
initial_count = self.min_workers
|
||||
initial_count = self._min_workers
|
||||
elif node_count < 50:
|
||||
initial_count = min(self.min_workers + 1, self.max_workers)
|
||||
initial_count = min(self._min_workers + 1, self._max_workers)
|
||||
else:
|
||||
initial_count = min(self.min_workers + 2, self.max_workers)
|
||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
|
|
@ -107,15 +107,15 @@ class SimpleWorkerPool:
|
|||
self._running = False
|
||||
|
||||
# Stop all workers
|
||||
for worker in self.workers:
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self.workers:
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
|
||||
self.workers.clear()
|
||||
self._workers.clear()
|
||||
|
||||
def _create_worker(self) -> None:
|
||||
"""Create and start a new worker."""
|
||||
|
|
@ -123,16 +123,16 @@ class SimpleWorkerPool:
|
|||
self._worker_counter += 1
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=self.ready_queue,
|
||||
event_queue=self.event_queue,
|
||||
graph=self.graph,
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
worker_id=worker_id,
|
||||
flask_app=self.flask_app,
|
||||
context_vars=self.context_vars,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
self.workers.append(worker)
|
||||
self._workers.append(worker)
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
|
|
@ -140,17 +140,17 @@ class SimpleWorkerPool:
|
|||
if not self._running:
|
||||
return
|
||||
|
||||
current_count = len(self.workers)
|
||||
queue_depth = self.ready_queue.qsize()
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Simple scaling logic
|
||||
if queue_depth > self.scale_up_threshold and current_count < self.max_workers:
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
self._create_worker()
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
with self._lock:
|
||||
return len(self.workers)
|
||||
return len(self._workers)
|
||||
|
||||
def get_status(self) -> dict[str, int]:
|
||||
"""
|
||||
|
|
@ -161,8 +161,8 @@ class SimpleWorkerPool:
|
|||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_workers": len(self.workers),
|
||||
"queue_depth": self.ready_queue.qsize(),
|
||||
"min_workers": self.min_workers,
|
||||
"max_workers": self.max_workers,
|
||||
"total_workers": len(self._workers),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"min_workers": self._min_workers,
|
||||
"max_workers": self._max_workers,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue