refactor(graph_engine): Correct naming of private attributes and methods

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 2025-09-01 04:37:23 +08:00
parent a5cb9d2b73
commit 0fdb1b2bc9
17 changed files with 287 additions and 272 deletions
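The pattern applied across all 17 files is the same: internal collaborators and mutable state gain a leading underscore, and the few attributes that outside callers still need are re-exposed through read-only properties (see the GraphEngine hunk below). A minimal sketch of the convention, with simplified names rather than the real constructor signature:

    class Engine:
        def __init__(self, graph, runtime_state) -> None:
            # A leading underscore marks these as internal implementation
            # details rather than part of the public API.
            self._graph = graph
            self._graph_runtime_state = runtime_state

        # Callers that legitimately need access go through read-only
        # properties instead of reaching for bare public attributes.
        @property
        def graph(self):
            return self._graph

        @property
        def graph_runtime_state(self):
            return self._graph_runtime_state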

View File

@@ -39,8 +39,8 @@ class CommandProcessor:
             command_channel: Channel for receiving commands
             graph_execution: Graph execution aggregate
         """
-        self.command_channel = command_channel
-        self.graph_execution = graph_execution
+        self._command_channel = command_channel
+        self._graph_execution = graph_execution
         self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
 
     def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
@@ -56,7 +56,7 @@ class CommandProcessor:
     def process_commands(self) -> None:
        """Check for and process any pending commands."""
         try:
-            commands = self.command_channel.fetch_commands()
+            commands = self._command_channel.fetch_commands()
             for command in commands:
                 self._handle_command(command)
         except Exception as e:
@@ -72,7 +72,7 @@ class CommandProcessor:
         handler = self._handlers.get(type(command))
         if handler:
             try:
-                handler.handle(command, self.graph_execution)
+                handler.handle(command, self._graph_execution)
             except Exception:
                 logger.exception("Error handling command %s", command.__class__.__name__)
             else:

View File

@@ -32,6 +32,8 @@ class AbortStrategy:
         Returns:
             None - signals abortion
         """
+        _ = graph
+        _ = retry_count
         logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
         # Return None to signal that execution should stop

View File

@@ -31,6 +31,7 @@ class DefaultValueStrategy:
         Returns:
             NodeRunExceptionEvent with default values
         """
+        _ = retry_count
         node = graph.nodes[event.node_id]
 
         outputs = {

View File

@@ -31,6 +31,8 @@ class FailBranchStrategy:
         Returns:
             NodeRunExceptionEvent to continue via fail branch
         """
+        _ = graph
+        _ = retry_count
         outputs = {
             "error_message": event.node_run_result.error,
             "error_type": event.node_run_result.error_type,

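The `_ = graph` and `_ = retry_count` lines added to these strategies discard parameters that the shared strategy signature requires but the individual strategy does not use, presumably to satisfy strict unused-argument lint checks without changing behavior. A hypothetical sketch of the idiom (not one of the real strategy classes):

    class NoopStrategy:
        def handle(self, event, graph, retry_count):
            # Required by the common signature but unused here; assigning
            # to `_` signals that the discard is intentional.
            _ = graph
            _ = retry_count
            return None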
View File

@@ -23,7 +23,7 @@ class ReadWriteLock:
 
     def acquire_read(self) -> None:
         """Acquire a read lock."""
-        self._read_ready.acquire()
+        _ = self._read_ready.acquire()
         try:
             self._readers += 1
         finally:
@@ -31,7 +31,7 @@ class ReadWriteLock:
 
     def release_read(self) -> None:
         """Release a read lock."""
-        self._read_ready.acquire()
+        _ = self._read_ready.acquire()
         try:
             self._readers -= 1
             if self._readers == 0:
@@ -41,9 +41,9 @@ class ReadWriteLock:
 
     def acquire_write(self) -> None:
         """Acquire a write lock."""
-        self._read_ready.acquire()
+        _ = self._read_ready.acquire()
         while self._readers > 0:
-            self._read_ready.wait()
+            _ = self._read_ready.wait()
 
     def release_write(self) -> None:
         """Release a write lock."""

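The same discard idiom appears here for return values: `Condition.acquire()` and `Condition.wait()` return booleans that this class never inspects, so `_ =` makes ignoring them explicit. For context, a hypothetical usage of a reader-writer lock like this one (the shared dict is invented for illustration): many readers may hold the lock at once, while a writer waits until the reader count reaches zero.

    lock = ReadWriteLock()
    shared: dict[str, int] = {}

    def read_value(key: str) -> int | None:
        lock.acquire_read()  # multiple readers may enter concurrently
        try:
            return shared.get(key)
        finally:
            lock.release_read()

    def write_value(key: str, value: int) -> None:
        lock.acquire_write()  # blocks until every reader has released
        try:
            shared[key] = value
        finally:
            lock.release_write()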
View File

@@ -28,7 +28,7 @@ class EventEmitter:
         Args:
             event_collector: The collector to emit events from
         """
-        self.event_collector = event_collector
+        self._event_collector = event_collector
         self._execution_complete = threading.Event()
 
     def mark_complete(self) -> None:
@@ -44,9 +44,9 @@ class EventEmitter:
         """
         yielded_count = 0
 
-        while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
+        while not self._execution_complete.is_set() or yielded_count < self._event_collector.event_count():
             # Get new events since last yield
-            new_events = self.event_collector.get_new_events(yielded_count)
+            new_events = self._event_collector.get_new_events(yielded_count)
 
             # Yield any new events
             for event in new_events:

View File

@@ -75,7 +75,7 @@ class GraphEngine:
         """Initialize the graph engine with separated concerns."""
 
         # Create domain models
-        self.execution_context = ExecutionContext(
+        self._execution_context = ExecutionContext(
             tenant_id=tenant_id,
             app_id=app_id,
             workflow_id=workflow_id,
@@ -87,13 +87,13 @@ class GraphEngine:
             max_execution_time=max_execution_time,
         )
 
-        self.graph_execution = GraphExecution(workflow_id=workflow_id)
+        self._graph_execution = GraphExecution(workflow_id=workflow_id)
 
         # Store core dependencies
-        self.graph = graph
-        self.graph_config = graph_config
-        self.graph_runtime_state = graph_runtime_state
-        self.command_channel = command_channel
+        self._graph = graph
+        self._graph_config = graph_config
+        self._graph_runtime_state = graph_runtime_state
+        self._command_channel = command_channel
 
         # Store worker management parameters
         self._min_workers = min_workers
@@ -102,8 +102,8 @@ class GraphEngine:
         self._scale_down_idle_time = scale_down_idle_time
 
         # Initialize queues
-        self.ready_queue: queue.Queue[str] = queue.Queue()
-        self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
+        self._ready_queue: queue.Queue[str] = queue.Queue()
+        self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
 
         # Initialize subsystems
         self._initialize_subsystems()
@@ -118,55 +118,55 @@ class GraphEngine:
         """Initialize all subsystems with proper dependency injection."""
         # Unified state management - single instance handles all state operations
-        self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
+        self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
 
         # Response coordination
-        self.response_coordinator = ResponseStreamCoordinator(
-            variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
+        self._response_coordinator = ResponseStreamCoordinator(
+            variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
         )
 
         # Event management
-        self.event_collector = EventCollector()
-        self.event_emitter = EventEmitter(self.event_collector)
+        self._event_collector = EventCollector()
+        self._event_emitter = EventEmitter(self._event_collector)
 
         # Error handling
-        self.error_handler = ErrorHandler(self.graph, self.graph_execution)
+        self._error_handler = ErrorHandler(self._graph, self._graph_execution)
 
         # Graph traversal
-        self.node_readiness_checker = NodeReadinessChecker(self.graph)
-        self.edge_processor = EdgeProcessor(
-            graph=self.graph,
-            state_manager=self.state_manager,
-            response_coordinator=self.response_coordinator,
+        self._node_readiness_checker = NodeReadinessChecker(self._graph)
+        self._edge_processor = EdgeProcessor(
+            graph=self._graph,
+            state_manager=self._state_manager,
+            response_coordinator=self._response_coordinator,
         )
 
-        self.skip_propagator = SkipPropagator(
-            graph=self.graph,
-            state_manager=self.state_manager,
+        self._skip_propagator = SkipPropagator(
+            graph=self._graph,
+            state_manager=self._state_manager,
        )
 
-        self.branch_handler = BranchHandler(
-            graph=self.graph,
-            edge_processor=self.edge_processor,
-            skip_propagator=self.skip_propagator,
-            state_manager=self.state_manager,
+        self._branch_handler = BranchHandler(
+            graph=self._graph,
+            edge_processor=self._edge_processor,
+            skip_propagator=self._skip_propagator,
+            state_manager=self._state_manager,
         )
 
         # Event handler registry with all dependencies
-        self.event_handler_registry = EventHandlerRegistry(
-            graph=self.graph,
-            graph_runtime_state=self.graph_runtime_state,
-            graph_execution=self.graph_execution,
-            response_coordinator=self.response_coordinator,
-            event_collector=self.event_collector,
-            branch_handler=self.branch_handler,
-            edge_processor=self.edge_processor,
-            state_manager=self.state_manager,
-            error_handler=self.error_handler,
+        self._event_handler_registry = EventHandlerRegistry(
+            graph=self._graph,
+            graph_runtime_state=self._graph_runtime_state,
+            graph_execution=self._graph_execution,
+            response_coordinator=self._response_coordinator,
+            event_collector=self._event_collector,
+            branch_handler=self._branch_handler,
+            edge_processor=self._edge_processor,
+            state_manager=self._state_manager,
+            error_handler=self._error_handler,
         )
 
         # Command processing
-        self.command_processor = CommandProcessor(
-            command_channel=self.command_channel,
-            graph_execution=self.graph_execution,
+        self._command_processor = CommandProcessor(
+            command_channel=self._command_channel,
+            graph_execution=self._graph_execution,
         )
         self._setup_command_handlers()
@@ -174,29 +174,29 @@ class GraphEngine:
         self._setup_worker_management()
 
         # Orchestration
-        self.execution_coordinator = ExecutionCoordinator(
-            graph_execution=self.graph_execution,
-            state_manager=self.state_manager,
-            event_handler=self.event_handler_registry,
-            event_collector=self.event_collector,
-            command_processor=self.command_processor,
+        self._execution_coordinator = ExecutionCoordinator(
+            graph_execution=self._graph_execution,
+            state_manager=self._state_manager,
+            event_handler=self._event_handler_registry,
+            event_collector=self._event_collector,
+            command_processor=self._command_processor,
             worker_pool=self._worker_pool,
         )
 
-        self.dispatcher = Dispatcher(
-            event_queue=self.event_queue,
-            event_handler=self.event_handler_registry,
-            event_collector=self.event_collector,
-            execution_coordinator=self.execution_coordinator,
-            max_execution_time=self.execution_context.max_execution_time,
-            event_emitter=self.event_emitter,
+        self._dispatcher = Dispatcher(
+            event_queue=self._event_queue,
+            event_handler=self._event_handler_registry,
+            event_collector=self._event_collector,
+            execution_coordinator=self._execution_coordinator,
+            max_execution_time=self._execution_context.max_execution_time,
+            event_emitter=self._event_emitter,
         )
 
     def _setup_command_handlers(self) -> None:
         """Configure command handlers."""
 
         # Create handler instance that follows the protocol
         abort_handler = AbortCommandHandler()
-        self.command_processor.register_handler(
+        self._command_processor.register_handler(
             AbortCommand,
             abort_handler,
         )
@@ -216,9 +216,9 @@ class GraphEngine:
 
         # Create simple worker pool
         self._worker_pool = SimpleWorkerPool(
-            ready_queue=self.ready_queue,
-            event_queue=self.event_queue,
-            graph=self.graph,
+            ready_queue=self._ready_queue,
+            event_queue=self._event_queue,
+            graph=self._graph,
             flask_app=flask_app,
             context_vars=context_vars,
             min_workers=self._min_workers,
@@ -229,8 +229,8 @@ class GraphEngine:
 
     def _validate_graph_state_consistency(self) -> None:
         """Validate that all nodes share the same GraphRuntimeState."""
-        expected_state_id = id(self.graph_runtime_state)
-        for node in self.graph.nodes.values():
+        expected_state_id = id(self._graph_runtime_state)
+        for node in self._graph.nodes.values():
             if id(node.graph_runtime_state) != expected_state_id:
                 raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
 
@@ -251,7 +251,7 @@ class GraphEngine:
             self._initialize_layers()
 
             # Start execution
-            self.graph_execution.start()
+            self._graph_execution.start()
             start_event = GraphRunStartedEvent()
             yield start_event
 
@@ -259,23 +259,23 @@ class GraphEngine:
             self._start_execution()
 
             # Yield events as they occur
-            yield from self.event_emitter.emit_events()
+            yield from self._event_emitter.emit_events()
 
             # Handle completion
-            if self.graph_execution.aborted:
+            if self._graph_execution.aborted:
                 abort_reason = "Workflow execution aborted by user command"
-                if self.graph_execution.error:
-                    abort_reason = str(self.graph_execution.error)
+                if self._graph_execution.error:
+                    abort_reason = str(self._graph_execution.error)
                 yield GraphRunAbortedEvent(
                     reason=abort_reason,
-                    outputs=self.graph_runtime_state.outputs,
+                    outputs=self._graph_runtime_state.outputs,
                 )
-            elif self.graph_execution.has_error:
-                if self.graph_execution.error:
-                    raise self.graph_execution.error
+            elif self._graph_execution.has_error:
+                if self._graph_execution.error:
+                    raise self._graph_execution.error
             else:
                 yield GraphRunSucceededEvent(
-                    outputs=self.graph_runtime_state.outputs,
+                    outputs=self._graph_runtime_state.outputs,
                 )
 
         except Exception as e:
@@ -287,10 +287,10 @@ class GraphEngine:
 
     def _initialize_layers(self) -> None:
         """Initialize layers with context."""
-        self.event_collector.set_layers(self._layers)
+        self._event_collector.set_layers(self._layers)
         for layer in self._layers:
            try:
-                layer.initialize(self.graph_runtime_state, self.command_channel)
+                layer.initialize(self._graph_runtime_state, self._command_channel)
             except Exception as e:
                 logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
 
@@ -305,21 +305,21 @@ class GraphEngine:
         self._worker_pool.start()
 
         # Register response nodes
-        for node in self.graph.nodes.values():
+        for node in self._graph.nodes.values():
             if node.execution_type == NodeExecutionType.RESPONSE:
-                self.response_coordinator.register(node.id)
+                self._response_coordinator.register(node.id)
 
         # Enqueue root node
-        root_node = self.graph.root_node
-        self.state_manager.enqueue_node(root_node.id)
-        self.state_manager.start_execution(root_node.id)
+        root_node = self._graph.root_node
+        self._state_manager.enqueue_node(root_node.id)
+        self._state_manager.start_execution(root_node.id)
 
         # Start dispatcher
-        self.dispatcher.start()
+        self._dispatcher.start()
 
     def _stop_execution(self) -> None:
         """Stop execution subsystems."""
-        self.dispatcher.stop()
+        self._dispatcher.stop()
         self._worker_pool.stop()
         # Don't mark complete here as the dispatcher already does it
 
@@ -328,6 +328,17 @@ class GraphEngine:
         for layer in self._layers:
             try:
-                layer.on_graph_end(self.graph_execution.error)
+                layer.on_graph_end(self._graph_execution.error)
             except Exception as e:
                 logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
 
+    # Public property accessors for attributes that need external access
+
+    @property
+    def graph_runtime_state(self) -> GraphRuntimeState:
+        """Get the graph runtime state."""
+        return self._graph_runtime_state
+
+    @property
+    def graph(self) -> Graph:
+        """Get the graph."""
+        return self._graph
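Because `run()` is a generator, the renamed internals stay hidden and callers interact with the engine almost entirely by iterating it; the two properties above cover the remaining external needs. A minimal consumption sketch, assuming an already-constructed `engine` instance (constructor arguments omitted):

    for event in engine.run():
        if isinstance(event, GraphRunSucceededEvent):
            print("outputs:", event.outputs)
        elif isinstance(event, GraphRunAbortedEvent):
            print("aborted:", event.reason)
        # node-level events such as stream chunks also arrive here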

View File

@@ -38,10 +38,10 @@ class BranchHandler:
             skip_propagator: Propagator for skip states
             state_manager: Unified state manager
         """
-        self.graph = graph
-        self.edge_processor = edge_processor
-        self.skip_propagator = skip_propagator
-        self.state_manager = state_manager
+        self._graph = graph
+        self._edge_processor = edge_processor
+        self._skip_propagator = skip_propagator
+        self._state_manager = state_manager
 
     def handle_branch_completion(
         self, node_id: str, selected_handle: str | None
@@ -63,13 +63,13 @@ class BranchHandler:
             raise ValueError(f"Branch node {node_id} completed without selecting a branch")
 
         # Categorize edges into selected and unselected
-        _, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
+        _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
 
         # Skip all unselected paths
-        self.skip_propagator.skip_branch_paths(unselected_edges)
+        self._skip_propagator.skip_branch_paths(unselected_edges)
 
         # Process selected edges and get ready nodes and streaming events
-        return self.edge_processor.process_node_success(node_id, selected_handle)
+        return self._edge_processor.process_node_success(node_id, selected_handle)
 
     def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
         """
@@ -82,6 +82,6 @@ class BranchHandler:
         Returns:
             True if the selection is valid
         """
-        outgoing_edges = self.graph.get_outgoing_edges(node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(node_id)
         valid_handles = {edge.source_handle for edge in outgoing_edges}
         return selected_handle in valid_handles

View File

@@ -36,9 +36,9 @@ class EdgeProcessor:
             state_manager: Unified state manager
             response_coordinator: Response stream coordinator
         """
-        self.graph = graph
-        self.state_manager = state_manager
-        self.response_coordinator = response_coordinator
+        self._graph = graph
+        self._state_manager = state_manager
+        self._response_coordinator = response_coordinator
 
     def process_node_success(
         self, node_id: str, selected_handle: str | None = None
@@ -53,7 +53,7 @@ class EdgeProcessor:
         Returns:
             Tuple of (list of downstream node IDs that are now ready, list of streaming events)
         """
-        node = self.graph.nodes[node_id]
+        node = self._graph.nodes[node_id]
 
         if node.execution_type == NodeExecutionType.BRANCH:
             return self._process_branch_node_edges(node_id, selected_handle)
@@ -72,7 +72,7 @@ class EdgeProcessor:
         """
         ready_nodes: list[str] = []
         all_streaming_events: list[NodeRunStreamChunkEvent] = []
-        outgoing_edges = self.graph.get_outgoing_edges(node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(node_id)
 
         for edge in outgoing_edges:
             nodes, events = self._process_taken_edge(edge)
@@ -104,7 +104,7 @@ class EdgeProcessor:
         all_streaming_events: list[NodeRunStreamChunkEvent] = []
 
         # Categorize edges
-        selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
+        selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
 
         # Process unselected edges first (mark as skipped)
         for edge in unselected_edges:
@@ -129,14 +129,14 @@ class EdgeProcessor:
             Tuple of (list containing downstream node ID if it's ready, list of streaming events)
         """
         # Mark edge as taken
-        self.state_manager.mark_edge_taken(edge.id)
+        self._state_manager.mark_edge_taken(edge.id)
 
         # Notify response coordinator and get streaming events
-        streaming_events = self.response_coordinator.on_edge_taken(edge.id)
+        streaming_events = self._response_coordinator.on_edge_taken(edge.id)
 
         # Check if downstream node is ready
         ready_nodes: list[str] = []
-        if self.state_manager.is_node_ready(edge.head):
+        if self._state_manager.is_node_ready(edge.head):
             ready_nodes.append(edge.head)
 
         return ready_nodes, streaming_events
@@ -148,4 +148,4 @@ class EdgeProcessor:
         Args:
             edge: The edge to skip
         """
-        self.state_manager.mark_edge_skipped(edge.id)
+        self._state_manager.mark_edge_skipped(edge.id)

View File

@@ -24,7 +24,7 @@ class NodeReadinessChecker:
         Args:
             graph: The workflow graph
         """
-        self.graph = graph
+        self._graph = graph
 
     def is_node_ready(self, node_id: str) -> bool:
         """
@@ -40,7 +40,7 @@ class NodeReadinessChecker:
         Returns:
             True if the node is ready for execution
         """
-        incoming_edges = self.graph.get_incoming_edges(node_id)
+        incoming_edges = self._graph.get_incoming_edges(node_id)
 
         # No dependencies means always ready
         if not incoming_edges:
@@ -75,7 +75,7 @@ class NodeReadinessChecker:
             List of node IDs that are now ready
         """
         ready_nodes: list[str] = []
-        outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(from_node_id)
 
         for edge in outgoing_edges:
             if edge.state == NodeState.TAKEN:

View File

@@ -31,8 +31,8 @@ class SkipPropagator:
             graph: The workflow graph
             state_manager: Unified state manager
         """
-        self.graph = graph
-        self.state_manager = state_manager
+        self._graph = graph
+        self._state_manager = state_manager
 
     def propagate_skip_from_edge(self, edge_id: str) -> None:
         """
@@ -46,11 +46,11 @@ class SkipPropagator:
         Args:
             edge_id: The ID of the skipped edge to start from
         """
-        downstream_node_id = self.graph.edges[edge_id].head
-        incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
+        downstream_node_id = self._graph.edges[edge_id].head
+        incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
 
         # Analyze edge states
-        edge_states = self.state_manager.analyze_edge_states(incoming_edges)
+        edge_states = self._state_manager.analyze_edge_states(incoming_edges)
 
         # Stop if there are unknown edges (not yet processed)
         if edge_states["has_unknown"]:
@@ -59,7 +59,7 @@ class SkipPropagator:
         # If any edge is taken, node may still execute
         if edge_states["has_taken"]:
             # Enqueue node
-            self.state_manager.enqueue_node(downstream_node_id)
+            self._state_manager.enqueue_node(downstream_node_id)
             return
 
         # All edges are skipped, propagate skip to this node
@@ -74,12 +74,12 @@ class SkipPropagator:
             node_id: The ID of the node to skip
         """
         # Mark node as skipped
-        self.state_manager.mark_node_skipped(node_id)
+        self._state_manager.mark_node_skipped(node_id)
 
         # Mark all outgoing edges as skipped and propagate
-        outgoing_edges = self.graph.get_outgoing_edges(node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(node_id)
         for edge in outgoing_edges:
-            self.state_manager.mark_edge_skipped(edge.id)
+            self._state_manager.mark_edge_skipped(edge.id)
 
             # Recursively propagate skip
             self.propagate_skip_from_edge(edge.id)
@@ -91,5 +91,5 @@ class SkipPropagator:
             unselected_edges: List of edges not taken by the branch
         """
         for edge in unselected_edges:
-            self.state_manager.mark_edge_skipped(edge.id)
+            self._state_manager.mark_edge_skipped(edge.id)
             self.propagate_skip_from_edge(edge.id)

View File

@@ -48,12 +48,12 @@ class Dispatcher:
             max_execution_time: Maximum execution time in seconds
             event_emitter: Optional event emitter to signal completion
         """
-        self.event_queue = event_queue
-        self.event_handler = event_handler
-        self.event_collector = event_collector
-        self.execution_coordinator = execution_coordinator
-        self.max_execution_time = max_execution_time
-        self.event_emitter = event_emitter
+        self._event_queue = event_queue
+        self._event_handler = event_handler
+        self._event_collector = event_collector
+        self._execution_coordinator = execution_coordinator
+        self._max_execution_time = max_execution_time
+        self._event_emitter = event_emitter
         self._thread: threading.Thread | None = None
         self._stop_event = threading.Event()
 
@@ -80,28 +80,28 @@ class Dispatcher:
        try:
             while not self._stop_event.is_set():
                 # Check for commands
-                self.execution_coordinator.check_commands()
+                self._execution_coordinator.check_commands()
 
                 # Check for scaling
-                self.execution_coordinator.check_scaling()
+                self._execution_coordinator.check_scaling()
 
                 # Process events
                 try:
-                    event = self.event_queue.get(timeout=0.1)
+                    event = self._event_queue.get(timeout=0.1)
                     # Route to the event handler
-                    self.event_handler.handle_event(event)
-                    self.event_queue.task_done()
+                    self._event_handler.handle_event(event)
+                    self._event_queue.task_done()
                 except queue.Empty:
                     # Check if execution is complete
-                    if self.execution_coordinator.is_execution_complete():
+                    if self._execution_coordinator.is_execution_complete():
                         break
         except Exception as e:
             logger.exception("Dispatcher error")
-            self.execution_coordinator.mark_failed(e)
+            self._execution_coordinator.mark_failed(e)
         finally:
-            self.execution_coordinator.mark_complete()
+            self._execution_coordinator.mark_complete()
             # Signal the event emitter that execution is complete
-            if self.event_emitter:
-                self.event_emitter.mark_complete()
+            if self._event_emitter:
+                self._event_emitter.mark_complete()
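The loop above is a standard bounded-wait consumer: a short `get(timeout=0.1)` keeps the thread responsive to command, scaling, and completion checks instead of blocking indefinitely. A generic restatement of the pattern with simplified names (a sketch, not the actual Dispatcher code):

    import queue
    import threading

    stop = threading.Event()
    events: queue.Queue[str] = queue.Queue()

    def dispatch_loop() -> None:
        while not stop.is_set():
            try:
                event = events.get(timeout=0.1)  # short timeout stays responsive
            except queue.Empty:
                continue  # nothing queued: re-check the stop flag and loop
            print("handling", event)  # stand-in for the real event handler
            events.task_done()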

View File

@@ -43,20 +43,20 @@ class ExecutionCoordinator:
             command_processor: Processor for commands
             worker_pool: Pool of workers
         """
-        self.graph_execution = graph_execution
-        self.state_manager = state_manager
-        self.event_handler = event_handler
-        self.event_collector = event_collector
-        self.command_processor = command_processor
-        self.worker_pool = worker_pool
+        self._graph_execution = graph_execution
+        self._state_manager = state_manager
+        self._event_handler = event_handler
+        self._event_collector = event_collector
+        self._command_processor = command_processor
+        self._worker_pool = worker_pool
 
     def check_commands(self) -> None:
         """Process any pending commands."""
-        self.command_processor.process_commands()
+        self._command_processor.process_commands()
 
     def check_scaling(self) -> None:
         """Check and perform worker scaling if needed."""
-        self.worker_pool.check_and_scale()
+        self._worker_pool.check_and_scale()
 
     def is_execution_complete(self) -> bool:
         """
@@ -66,16 +66,16 @@ class ExecutionCoordinator:
             True if execution is complete
         """
         # Check if aborted or failed
-        if self.graph_execution.aborted or self.graph_execution.has_error:
+        if self._graph_execution.aborted or self._graph_execution.has_error:
             return True
 
         # Complete if no work remains
-        return self.state_manager.is_execution_complete()
+        return self._state_manager.is_execution_complete()
 
     def mark_complete(self) -> None:
         """Mark execution as complete."""
-        if not self.graph_execution.completed:
-            self.graph_execution.complete()
+        if not self._graph_execution.completed:
+            self._graph_execution.complete()
 
     def mark_failed(self, error: Exception) -> None:
         """
@@ -84,4 +84,4 @@ class ExecutionCoordinator:
         Args:
             error: The error that caused failure
         """
-        self.graph_execution.fail(error)
+        self._graph_execution.fail(error)

View File

@@ -44,11 +44,11 @@ class ResponseStreamCoordinator:
             variable_pool: VariablePool instance for accessing node variables
             graph: Graph instance for looking up node information
         """
-        self.variable_pool = variable_pool
-        self.graph = graph
-        self.active_session: ResponseSession | None = None
-        self.waiting_sessions: deque[ResponseSession] = deque()
-        self.lock = RLock()
+        self._variable_pool = variable_pool
+        self._graph = graph
+        self._active_session: ResponseSession | None = None
+        self._waiting_sessions: deque[ResponseSession] = deque()
+        self._lock = RLock()
 
         # Internal stream management (replacing OutputRegistry)
         self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
@@ -68,7 +68,7 @@ class ResponseStreamCoordinator:
         self._response_sessions: dict[NodeID, ResponseSession] = {}  # node_id -> session
 
     def register(self, response_node_id: NodeID) -> None:
-        with self.lock:
+        with self._lock:
             self._response_nodes.add(response_node_id)
 
             # Build and save paths map for this response node
@@ -76,7 +76,7 @@ class ResponseStreamCoordinator:
             self._paths_maps[response_node_id] = paths_map
 
             # Create and store response session for this node
-            response_node = self.graph.nodes[response_node_id]
+            response_node = self._graph.nodes[response_node_id]
             session = ResponseSession.from_node(response_node)
             self._response_sessions[response_node_id] = session
 
@@ -87,7 +87,7 @@ class ResponseStreamCoordinator:
             node_id: The ID of the node
             execution_id: The execution ID from NodeRunStartedEvent
         """
-        with self.lock:
+        with self._lock:
             self._node_execution_ids[node_id] = execution_id
 
     def _get_or_create_execution_id(self, node_id: NodeID) -> str:
@@ -99,7 +99,7 @@ class ResponseStreamCoordinator:
         Returns:
             The execution ID for the node
         """
-        with self.lock:
+        with self._lock:
             if node_id not in self._node_execution_ids:
                 self._node_execution_ids[node_id] = str(uuid4())
             return self._node_execution_ids[node_id]
@@ -116,14 +116,14 @@ class ResponseStreamCoordinator:
             List of Path objects, where each path contains branch edge IDs
         """
         # Get root node ID
-        root_node_id = self.graph.root_node.id
+        root_node_id = self._graph.root_node.id
 
         # If root is the response node, return empty path
         if root_node_id == response_node_id:
             return [Path()]
 
         # Extract variable selectors from the response node's template
-        response_node = self.graph.nodes[response_node_id]
+        response_node = self._graph.nodes[response_node_id]
         response_session = ResponseSession.from_node(response_node)
         template = response_session.template
 
@@ -149,7 +149,7 @@ class ResponseStreamCoordinator:
             visited.add(current_node_id)
 
             # Explore outgoing edges
-            outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
+            outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
             for edge in outgoing_edges:
                 edge_id = edge.id
                 next_node_id = edge.head
@@ -168,8 +168,8 @@ class ResponseStreamCoordinator:
         for path in all_complete_paths:
             blocking_edges: list[str] = []
             for edge_id in path:
-                edge = self.graph.edges[edge_id]
-                source_node = self.graph.nodes[edge.tail]
+                edge = self._graph.edges[edge_id]
+                source_node = self._graph.nodes[edge.tail]
 
                 # Check if node is a branch/container (original behavior)
                 if source_node.execution_type in {
@@ -199,7 +199,7 @@ class ResponseStreamCoordinator:
         """
         events: list[NodeRunStreamChunkEvent] = []
 
-        with self.lock:
+        with self._lock:
            # Check each response node in order
            for response_node_id in self._response_nodes:
                if response_node_id not in self._paths_maps:
@@ -245,21 +245,21 @@ class ResponseStreamCoordinator:
             # Remove from map to ensure it won't be activated again
             del self._response_sessions[node_id]
 
-            if self.active_session is None:
-                self.active_session = session
+            if self._active_session is None:
+                self._active_session = session
 
                 # Try to flush immediately
                 events.extend(self.try_flush())
             else:
                 # Queue the session if another is active
-                self.waiting_sessions.append(session)
+                self._waiting_sessions.append(session)
 
         return events
 
     def intercept_event(
         self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
     ) -> Sequence[NodeRunStreamChunkEvent]:
-        with self.lock:
+        with self._lock:
             if isinstance(event, NodeRunStreamChunkEvent):
                 self._append_stream_chunk(event.selector, event)
                 if event.is_final:
@@ -269,9 +269,8 @@ class ResponseStreamCoordinator:
                 # Skip cause we share the same variable pool.
                 #
                 # for variable_name, variable_value in event.node_run_result.outputs.items():
-                #     self.variable_pool.add((event.node_id, variable_name), variable_value)
+                #     self._variable_pool.add((event.node_id, variable_name), variable_value)
                 return self.try_flush()
-
            return []
 
    def _create_stream_chunk_event(
        self,
@@ -287,9 +286,9 @@ class ResponseStreamCoordinator:
         active response node's information since these are not actual node IDs.
         """
         # Check if this is a special selector that doesn't correspond to a node
-        if selector and selector[0] not in self.graph.nodes and self.active_session:
+        if selector and selector[0] not in self._graph.nodes and self._active_session:
             # Use the active response node for special selectors
-            response_node = self.graph.nodes[self.active_session.node_id]
+            response_node = self._graph.nodes[self._active_session.node_id]
             return NodeRunStreamChunkEvent(
                 id=execution_id,
                 node_id=response_node.id,
@@ -300,7 +299,7 @@ class ResponseStreamCoordinator:
             )
 
         # Standard case: selector refers to an actual node
-        node = self.graph.nodes[node_id]
+        node = self._graph.nodes[node_id]
         return NodeRunStreamChunkEvent(
             id=execution_id,
             node_id=node.id,
@@ -323,9 +322,9 @@ class ResponseStreamCoordinator:
         # Determine which node to attribute the output to
         # For special selectors (sys, env, conversation), use the active response node
         # For regular selectors, use the source node
-        if self.active_session and source_selector_prefix not in self.graph.nodes:
+        if self._active_session and source_selector_prefix not in self._graph.nodes:
             # Special selector - use active response node
-            output_node_id = self.active_session.node_id
+            output_node_id = self._active_session.node_id
         else:
             # Regular node selector
             output_node_id = source_selector_prefix
@@ -336,8 +335,8 @@ class ResponseStreamCoordinator:
         if event := self._pop_stream_chunk(segment.selector):
             # For special selectors, we need to update the event to use
             # the active response node's information
-            if self.active_session and source_selector_prefix not in self.graph.nodes:
-                response_node = self.graph.nodes[self.active_session.node_id]
+            if self._active_session and source_selector_prefix not in self._graph.nodes:
+                response_node = self._graph.nodes[self._active_session.node_id]
                 # Create a new event with the response node's information
                 # but keep the original selector
                 updated_event = NodeRunStreamChunkEvent(
@@ -359,10 +358,10 @@ class ResponseStreamCoordinator:
 
             if stream_closed:
                 is_complete = True
-        elif value := self.variable_pool.get(segment.selector):
+        elif value := self._variable_pool.get(segment.selector):
             # Process scalar value
             is_last_segment = bool(
-                self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
+                self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
             )
             events.append(
                 self._create_stream_chunk_event(
@@ -379,13 +378,13 @@ class ResponseStreamCoordinator:
 
     def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
         """Process a text segment. Returns (events, is_complete)."""
-        assert self.active_session is not None
-        current_response_node = self.graph.nodes[self.active_session.node_id]
+        assert self._active_session is not None
+        current_response_node = self._graph.nodes[self._active_session.node_id]
 
         # Use get_or_create_execution_id to ensure we have a consistent ID
         execution_id = self._get_or_create_execution_id(current_response_node.id)
-        is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1
+        is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
 
         event = self._create_stream_chunk_event(
             node_id=current_response_node.id,
             execution_id=execution_id,
@@ -396,29 +395,29 @@ class ResponseStreamCoordinator:
         return [event]
 
     def try_flush(self) -> list[NodeRunStreamChunkEvent]:
-        with self.lock:
-            if not self.active_session:
+        with self._lock:
+            if not self._active_session:
                 return []
 
-            template = self.active_session.template
-            response_node_id = self.active_session.node_id
+            template = self._active_session.template
+            response_node_id = self._active_session.node_id
             events: list[NodeRunStreamChunkEvent] = []
 
             # Process segments sequentially from current index
-            while self.active_session.index < len(template.segments):
-                segment = template.segments[self.active_session.index]
+            while self._active_session.index < len(template.segments):
+                segment = template.segments[self._active_session.index]
 
                 if isinstance(segment, VariableSegment):
                     # Check if the source node for this variable is skipped
                     # Only check for actual nodes, not special selectors (sys, env, conversation)
                     source_selector_prefix = segment.selector[0] if segment.selector else ""
-                    if source_selector_prefix in self.graph.nodes:
-                        source_node = self.graph.nodes[source_selector_prefix]
+                    if source_selector_prefix in self._graph.nodes:
+                        source_node = self._graph.nodes[source_selector_prefix]
                         if source_node.state == NodeState.SKIPPED:
                             # Skip this variable segment if the source node is skipped
-                            self.active_session.index += 1
+                            self._active_session.index += 1
                             continue
 
                     segment_events, is_complete = self._process_variable_segment(segment)
@@ -426,7 +425,7 @@ class ResponseStreamCoordinator:
 
                     # Only advance index if this variable segment is complete
                     if is_complete:
-                        self.active_session.index += 1
+                        self._active_session.index += 1
                     else:
                         # Wait for more data
                        break
@@ -434,9 +433,9 @@ class ResponseStreamCoordinator:
                 else:
                     segment_events = self._process_text_segment(segment)
                     events.extend(segment_events)
-                    self.active_session.index += 1
+                    self._active_session.index += 1
 
-            if self.active_session.is_complete():
+            if self._active_session.is_complete():
                 # End current session and get events from starting next session
                 next_session_events = self.end_session(response_node_id)
                 events.extend(next_session_events)
@@ -454,16 +453,16 @@ class ResponseStreamCoordinator:
         Returns:
             List of events from starting the next session
         """
-        with self.lock:
+        with self._lock:
             events: list[NodeRunStreamChunkEvent] = []
 
-            if self.active_session and self.active_session.node_id == node_id:
-                self.active_session = None
+            if self._active_session and self._active_session.node_id == node_id:
+                self._active_session = None
 
                 # Try to start next waiting session
-                if self.waiting_sessions:
-                    next_session = self.waiting_sessions.popleft()
-                    self.active_session = next_session
+                if self._waiting_sessions:
+                    next_session = self._waiting_sessions.popleft()
+                    self._active_session = next_session
 
                     # Immediately try to flush any available segments
                     events = self.try_flush()

View File

@@ -46,8 +46,8 @@ class UnifiedStateManager:
             graph: The workflow graph
             ready_queue: Queue for nodes ready to execute
         """
-        self.graph = graph
-        self.ready_queue = ready_queue
+        self._graph = graph
+        self._ready_queue = ready_queue
         self._lock = threading.RLock()
 
         # Execution tracking state
@@ -66,8 +66,8 @@ class UnifiedStateManager:
             node_id: The ID of the node to enqueue
         """
         with self._lock:
-            self.graph.nodes[node_id].state = NodeState.TAKEN
-            self.ready_queue.put(node_id)
+            self._graph.nodes[node_id].state = NodeState.TAKEN
+            self._ready_queue.put(node_id)
 
     def mark_node_skipped(self, node_id: str) -> None:
         """
@@ -77,7 +77,7 @@ class UnifiedStateManager:
             node_id: The ID of the node to skip
         """
         with self._lock:
-            self.graph.nodes[node_id].state = NodeState.SKIPPED
+            self._graph.nodes[node_id].state = NodeState.SKIPPED
 
     def is_node_ready(self, node_id: str) -> bool:
         """
@@ -94,7 +94,7 @@ class UnifiedStateManager:
         """
         with self._lock:
             # Get all incoming edges to this node
-            incoming_edges = self.graph.get_incoming_edges(node_id)
+            incoming_edges = self._graph.get_incoming_edges(node_id)
 
             # If no incoming edges, node is always ready
             if not incoming_edges:
@@ -118,7 +118,7 @@ class UnifiedStateManager:
             The current node state
         """
         with self._lock:
-            return self.graph.nodes[node_id].state
+            return self._graph.nodes[node_id].state
 
     # ============= Edge State Operations =============
 
@@ -130,7 +130,7 @@ class UnifiedStateManager:
             edge_id: The ID of the edge to mark
         """
         with self._lock:
-            self.graph.edges[edge_id].state = NodeState.TAKEN
+            self._graph.edges[edge_id].state = NodeState.TAKEN
 
     def mark_edge_skipped(self, edge_id: str) -> None:
         """
@@ -140,7 +140,7 @@ class UnifiedStateManager:
             edge_id: The ID of the edge to mark
         """
         with self._lock:
-            self.graph.edges[edge_id].state = NodeState.SKIPPED
+            self._graph.edges[edge_id].state = NodeState.SKIPPED
 
     def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
         """
@@ -172,7 +172,7 @@ class UnifiedStateManager:
             The current edge state
         """
         with self._lock:
-            return self.graph.edges[edge_id].state
+            return self._graph.edges[edge_id].state
 
     def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
         """
@@ -186,7 +186,7 @@ class UnifiedStateManager:
             A tuple of (selected_edges, unselected_edges)
         """
         with self._lock:
-            outgoing_edges = self.graph.get_outgoing_edges(node_id)
+            outgoing_edges = self._graph.get_outgoing_edges(node_id)
 
             selected_edges: list[Edge] = []
             unselected_edges: list[Edge] = []
@@ -272,7 +272,7 @@ class UnifiedStateManager:
             True if execution is complete
         """
         with self._lock:
-            return self.ready_queue.empty() and len(self._executing_nodes) == 0
+            return self._ready_queue.empty() and len(self._executing_nodes) == 0
 
     def get_queue_depth(self) -> int:
         """
@@ -281,7 +281,7 @@ class UnifiedStateManager:
         Returns:
             Number of nodes in the ready queue
         """
-        return self.ready_queue.qsize()
+        return self._ready_queue.qsize()
 
     def get_execution_stats(self) -> dict[str, int]:
         """
@@ -291,12 +291,12 @@ class UnifiedStateManager:
             Dictionary with execution statistics
         """
         with self._lock:
-            taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN)
-            skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED)
-            unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN)
+            taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
+            skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
+            unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
 
             return {
-                "queue_depth": self.ready_queue.qsize(),
+                "queue_depth": self._ready_queue.qsize(),
                 "executing": len(self._executing_nodes),
                 "taken_nodes": taken_nodes,
                 "skipped_nodes": skipped_nodes,

View File

@@ -59,16 +59,16 @@ class Worker(threading.Thread):
             on_active_callback: Optional callback when worker becomes active
         """
         super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
-        self.ready_queue = ready_queue
-        self.event_queue = event_queue
-        self.graph = graph
-        self.worker_id = worker_id
-        self.flask_app = flask_app
-        self.context_vars = context_vars
+        self._ready_queue = ready_queue
+        self._event_queue = event_queue
+        self._graph = graph
+        self._worker_id = worker_id
+        self._flask_app = flask_app
+        self._context_vars = context_vars
         self._stop_event = threading.Event()
-        self.on_idle_callback = on_idle_callback
-        self.on_active_callback = on_active_callback
-        self.last_task_time = time.time()
+        self._on_idle_callback = on_idle_callback
+        self._on_active_callback = on_active_callback
+        self._last_task_time = time.time()
 
     def stop(self) -> None:
         """Signal the worker to stop processing."""
@@ -85,22 +85,22 @@ class Worker(threading.Thread):
         while not self._stop_event.is_set():
             # Try to get a node ID from the ready queue (with timeout)
             try:
-                node_id = self.ready_queue.get(timeout=0.1)
+                node_id = self._ready_queue.get(timeout=0.1)
             except queue.Empty:
                 # Notify that worker is idle
-                if self.on_idle_callback:
-                    self.on_idle_callback(self.worker_id)
+                if self._on_idle_callback:
+                    self._on_idle_callback(self._worker_id)
                 continue
 
             # Notify that worker is active
-            if self.on_active_callback:
-                self.on_active_callback(self.worker_id)
+            if self._on_active_callback:
+                self._on_active_callback(self._worker_id)
 
-            self.last_task_time = time.time()
-            node = self.graph.nodes[node_id]
+            self._last_task_time = time.time()
+            node = self._graph.nodes[node_id]
             try:
                 self._execute_node(node)
-                self.ready_queue.task_done()
+                self._ready_queue.task_done()
             except Exception as e:
                 error_event = NodeRunFailedEvent(
                     id=str(uuid4()),
@@ -110,7 +110,7 @@ class Worker(threading.Thread):
                     error=str(e),
                     start_at=datetime.now(),
                 )
-                self.event_queue.put(error_event)
+                self._event_queue.put(error_event)
 
     def _execute_node(self, node: Node) -> None:
         """
@@ -120,19 +120,19 @@ class Worker(threading.Thread):
             node: The node instance to execute
         """
         # Execute the node with preserved context if Flask app is provided
-        if self.flask_app and self.context_vars:
+        if self._flask_app and self._context_vars:
             with preserve_flask_contexts(
-                flask_app=self.flask_app,
-                context_vars=self.context_vars,
+                flask_app=self._flask_app,
+                context_vars=self._context_vars,
             ):
                 # Execute the node
                 node_events = node.run()
                 for event in node_events:
                     # Forward event to dispatcher immediately for streaming
-                    self.event_queue.put(event)
+                    self._event_queue.put(event)
         else:
             # Execute without context preservation
             node_events = node.run()
             for event in node_events:
                 # Forward event to dispatcher immediately for streaming
-                self.event_queue.put(event)
+                self._event_queue.put(event)

View File

@@ -56,20 +56,20 @@ class SimpleWorkerPool:
             scale_up_threshold: Queue depth to trigger scale up
             scale_down_idle_time: Seconds before scaling down idle workers
         """
-        self.ready_queue = ready_queue
-        self.event_queue = event_queue
-        self.graph = graph
-        self.flask_app = flask_app
-        self.context_vars = context_vars
+        self._ready_queue = ready_queue
+        self._event_queue = event_queue
+        self._graph = graph
+        self._flask_app = flask_app
+        self._context_vars = context_vars
 
         # Scaling parameters with defaults
-        self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
-        self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
-        self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
-        self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
+        self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
+        self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
+        self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
+        self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
 
         # Worker management
-        self.workers: list[Worker] = []
+        self._workers: list[Worker] = []
         self._worker_counter = 0
         self._lock = threading.RLock()
         self._running = False
@@ -89,13 +89,13 @@ class SimpleWorkerPool:
 
            # Calculate initial worker count
            if initial_count is None:
-                node_count = len(self.graph.nodes)
+                node_count = len(self._graph.nodes)
                if node_count < 10:
-                    initial_count = self.min_workers
+                    initial_count = self._min_workers
                elif node_count < 50:
-                    initial_count = min(self.min_workers + 1, self.max_workers)
+                    initial_count = min(self._min_workers + 1, self._max_workers)
                else:
-                    initial_count = min(self.min_workers + 2, self.max_workers)
+                    initial_count = min(self._min_workers + 2, self._max_workers)
 
            # Create initial workers
            for _ in range(initial_count):
@@ -107,15 +107,15 @@ class SimpleWorkerPool:
            self._running = False
 
            # Stop all workers
-            for worker in self.workers:
+            for worker in self._workers:
                worker.stop()
 
            # Wait for workers to finish
-            for worker in self.workers:
+            for worker in self._workers:
                if worker.is_alive():
                    worker.join(timeout=10.0)
 
-            self.workers.clear()
+            self._workers.clear()
 
    def _create_worker(self) -> None:
        """Create and start a new worker."""
@@ -123,16 +123,16 @@ class SimpleWorkerPool:
            self._worker_counter += 1
 
            worker = Worker(
-                ready_queue=self.ready_queue,
-                event_queue=self.event_queue,
-                graph=self.graph,
+                ready_queue=self._ready_queue,
+                event_queue=self._event_queue,
+                graph=self._graph,
                worker_id=worker_id,
-                flask_app=self.flask_app,
-                context_vars=self.context_vars,
+                flask_app=self._flask_app,
+                context_vars=self._context_vars,
            )
 
            worker.start()
-            self.workers.append(worker)
+            self._workers.append(worker)
 
    def check_and_scale(self) -> None:
        """Check and perform scaling if needed."""
@@ -140,17 +140,17 @@ class SimpleWorkerPool:
            if not self._running:
                return
 
-            current_count = len(self.workers)
-            queue_depth = self.ready_queue.qsize()
+            current_count = len(self._workers)
+            queue_depth = self._ready_queue.qsize()
 
            # Simple scaling logic
-            if queue_depth > self.scale_up_threshold and current_count < self.max_workers:
+            if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
                self._create_worker()
 
    def get_worker_count(self) -> int:
        """Get current number of workers."""
        with self._lock:
-            return len(self.workers)
+            return len(self._workers)
 
    def get_status(self) -> dict[str, int]:
        """
@@ -161,8 +161,8 @@ class SimpleWorkerPool:
        """
        with self._lock:
            return {
-                "total_workers": len(self.workers),
-                "queue_depth": self.ready_queue.qsize(),
-                "min_workers": self.min_workers,
-                "max_workers": self.max_workers,
+                "total_workers": len(self._workers),
+                "queue_depth": self._ready_queue.qsize(),
+                "min_workers": self._min_workers,
+                "max_workers": self._max_workers,
            }
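Restating the sizing heuristic from `start()` above as a worked example, assuming illustrative defaults of min_workers=2 and max_workers=10 (the real defaults come from `dify_config`):

    def initial_worker_count(node_count: int, min_workers: int = 2, max_workers: int = 10) -> int:
        # Mirrors the branching in SimpleWorkerPool.start()
        if node_count < 10:
            return min_workers
        elif node_count < 50:
            return min(min_workers + 1, max_workers)
        return min(min_workers + 2, max_workers)

    assert initial_worker_count(5) == 2    # small graph: minimum pool size
    assert initial_worker_count(30) == 3   # medium graph: one extra worker
    assert initial_worker_count(200) == 4  # large graph: two extra workers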