mirror of
https://github.com/langgenius/dify.git
synced 2026-03-20 01:40:38 +08:00
revert(graph-engine): rollback stop-event unification (#32789)
This commit is contained in:
parent
b462a96fa0
commit
ffe77fecdf
@ -9,7 +9,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
import threading
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import TYPE_CHECKING, cast, final
|
from typing import TYPE_CHECKING, cast, final
|
||||||
|
|
||||||
@ -77,13 +76,10 @@ class GraphEngine:
|
|||||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||||
# stop event
|
|
||||||
self._stop_event = threading.Event()
|
|
||||||
|
|
||||||
# Bind runtime state to current workflow context
|
# Bind runtime state to current workflow context
|
||||||
self._graph = graph
|
self._graph = graph
|
||||||
self._graph_runtime_state = graph_runtime_state
|
self._graph_runtime_state = graph_runtime_state
|
||||||
self._graph_runtime_state.stop_event = self._stop_event
|
|
||||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||||
self._command_channel = command_channel
|
self._command_channel = command_channel
|
||||||
self._config = config
|
self._config = config
|
||||||
@ -163,7 +159,6 @@ class GraphEngine:
|
|||||||
layers=self._layers,
|
layers=self._layers,
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
config=self._config,
|
config=self._config,
|
||||||
stop_event=self._stop_event,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# === Orchestration ===
|
# === Orchestration ===
|
||||||
@ -194,7 +189,6 @@ class GraphEngine:
|
|||||||
event_handler=self._event_handler_registry,
|
event_handler=self._event_handler_registry,
|
||||||
execution_coordinator=self._execution_coordinator,
|
execution_coordinator=self._execution_coordinator,
|
||||||
event_emitter=self._event_manager,
|
event_emitter=self._event_manager,
|
||||||
stop_event=self._stop_event,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# === Validation ===
|
# === Validation ===
|
||||||
@ -314,7 +308,6 @@ class GraphEngine:
|
|||||||
|
|
||||||
def _start_execution(self, *, resume: bool = False) -> None:
|
def _start_execution(self, *, resume: bool = False) -> None:
|
||||||
"""Start execution subsystems."""
|
"""Start execution subsystems."""
|
||||||
self._stop_event.clear()
|
|
||||||
paused_nodes: list[str] = []
|
paused_nodes: list[str] = []
|
||||||
deferred_nodes: list[str] = []
|
deferred_nodes: list[str] = []
|
||||||
if resume:
|
if resume:
|
||||||
@ -348,7 +341,6 @@ class GraphEngine:
|
|||||||
|
|
||||||
def _stop_execution(self) -> None:
|
def _stop_execution(self) -> None:
|
||||||
"""Stop execution subsystems."""
|
"""Stop execution subsystems."""
|
||||||
self._stop_event.set()
|
|
||||||
self._dispatcher.stop()
|
self._dispatcher.stop()
|
||||||
self._worker_pool.stop()
|
self._worker_pool.stop()
|
||||||
# Don't mark complete here as the dispatcher already does it
|
# Don't mark complete here as the dispatcher already does it
|
||||||
|
|||||||
@ -44,7 +44,6 @@ class Dispatcher:
|
|||||||
event_queue: queue.Queue[GraphNodeEventBase],
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
event_handler: "EventHandler",
|
event_handler: "EventHandler",
|
||||||
execution_coordinator: ExecutionCoordinator,
|
execution_coordinator: ExecutionCoordinator,
|
||||||
stop_event: threading.Event,
|
|
||||||
event_emitter: EventManager | None = None,
|
event_emitter: EventManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -62,7 +61,7 @@ class Dispatcher:
|
|||||||
self._event_emitter = event_emitter
|
self._event_emitter = event_emitter
|
||||||
|
|
||||||
self._thread: threading.Thread | None = None
|
self._thread: threading.Thread | None = None
|
||||||
self._stop_event = stop_event
|
self._stop_event = threading.Event()
|
||||||
self._start_time: float | None = None
|
self._start_time: float | None = None
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
@ -70,12 +69,14 @@ class Dispatcher:
|
|||||||
if self._thread and self._thread.is_alive():
|
if self._thread and self._thread.is_alive():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._stop_event.clear()
|
||||||
self._start_time = time.time()
|
self._start_time = time.time()
|
||||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the dispatcher thread."""
|
"""Stop the dispatcher thread."""
|
||||||
|
self._stop_event.set()
|
||||||
if self._thread and self._thread.is_alive():
|
if self._thread and self._thread.is_alive():
|
||||||
self._thread.join(timeout=2.0)
|
self._thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
|||||||
@ -42,7 +42,6 @@ class Worker(threading.Thread):
|
|||||||
event_queue: queue.Queue[GraphNodeEventBase],
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
layers: Sequence[GraphEngineLayer],
|
layers: Sequence[GraphEngineLayer],
|
||||||
stop_event: threading.Event,
|
|
||||||
worker_id: int = 0,
|
worker_id: int = 0,
|
||||||
execution_context: IExecutionContext | None = None,
|
execution_context: IExecutionContext | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -63,16 +62,13 @@ class Worker(threading.Thread):
|
|||||||
self._graph = graph
|
self._graph = graph
|
||||||
self._worker_id = worker_id
|
self._worker_id = worker_id
|
||||||
self._execution_context = execution_context
|
self._execution_context = execution_context
|
||||||
self._stop_event = stop_event
|
self._stop_event = threading.Event()
|
||||||
self._layers = layers if layers is not None else []
|
self._layers = layers if layers is not None else []
|
||||||
self._last_task_time = time.time()
|
self._last_task_time = time.time()
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
"""Signal the worker to stop processing."""
|
||||||
|
self._stop_event.set()
|
||||||
This method is a no-op retained for backward compatibility.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_idle(self) -> bool:
|
def is_idle(self) -> bool:
|
||||||
|
|||||||
@ -37,7 +37,6 @@ class WorkerPool:
|
|||||||
event_queue: queue.Queue[GraphNodeEventBase],
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
layers: list[GraphEngineLayer],
|
layers: list[GraphEngineLayer],
|
||||||
stop_event: threading.Event,
|
|
||||||
config: GraphEngineConfig,
|
config: GraphEngineConfig,
|
||||||
execution_context: IExecutionContext | None = None,
|
execution_context: IExecutionContext | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -64,7 +63,6 @@ class WorkerPool:
|
|||||||
self._worker_counter = 0
|
self._worker_counter = 0
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._running = False
|
self._running = False
|
||||||
self._stop_event = stop_event
|
|
||||||
|
|
||||||
# No longer tracking worker states with callbacks to avoid lock contention
|
# No longer tracking worker states with callbacks to avoid lock contention
|
||||||
|
|
||||||
@ -135,7 +133,6 @@ class WorkerPool:
|
|||||||
layers=self._layers,
|
layers=self._layers,
|
||||||
worker_id=worker_id,
|
worker_id=worker_id,
|
||||||
execution_context=self._execution_context,
|
execution_context=self._execution_context,
|
||||||
stop_event=self._stop_event,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
worker.start()
|
worker.start()
|
||||||
|
|||||||
@ -302,10 +302,6 @@ class Node(Generic[NodeDataT]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _should_stop(self) -> bool:
|
|
||||||
"""Check if execution should be stopped."""
|
|
||||||
return self.graph_runtime_state.stop_event.is_set()
|
|
||||||
|
|
||||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||||
execution_id = self.ensure_execution_id()
|
execution_id = self.ensure_execution_id()
|
||||||
self._start_at = naive_utc_now()
|
self._start_at = naive_utc_now()
|
||||||
@ -374,21 +370,6 @@ class Node(Generic[NodeDataT]):
|
|||||||
yield event
|
yield event
|
||||||
else:
|
else:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
if self._should_stop():
|
|
||||||
error_message = "Execution cancelled"
|
|
||||||
yield NodeRunFailedEvent(
|
|
||||||
id=self.execution_id,
|
|
||||||
node_id=self._node_id,
|
|
||||||
node_type=self.node_type,
|
|
||||||
start_at=self._start_at,
|
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
|
||||||
error=error_message,
|
|
||||||
),
|
|
||||||
error=error_message,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Node %s failed to run", self._node_id)
|
logger.exception("Node %s failed to run", self._node_id)
|
||||||
result = NodeRunResult(
|
result = NodeRunResult(
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import threading
|
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -219,8 +218,6 @@ class GraphRuntimeState:
|
|||||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||||
|
|
||||||
self.stop_event: threading.Event = threading.Event()
|
|
||||||
|
|
||||||
if graph is not None:
|
if graph is not None:
|
||||||
self.attach_graph(graph)
|
self.attach_graph(graph)
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import queue
|
import queue
|
||||||
import threading
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from core.workflow.entities.pause_reason import SchedulingPause
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
@ -37,7 +36,6 @@ def test_dispatcher_should_consume_remains_events_after_pause():
|
|||||||
event_queue=event_queue,
|
event_queue=event_queue,
|
||||||
event_handler=event_handler,
|
event_handler=event_handler,
|
||||||
execution_coordinator=execution_coordinator,
|
execution_coordinator=execution_coordinator,
|
||||||
stop_event=threading.Event(),
|
|
||||||
)
|
)
|
||||||
dispatcher._dispatcher_loop()
|
dispatcher._dispatcher_loop()
|
||||||
assert event_queue.empty()
|
assert event_queue.empty()
|
||||||
@ -98,7 +96,6 @@ def _run_dispatcher_for_event(event) -> int:
|
|||||||
event_queue=event_queue,
|
event_queue=event_queue,
|
||||||
event_handler=event_handler,
|
event_handler=event_handler,
|
||||||
execution_coordinator=coordinator,
|
execution_coordinator=coordinator,
|
||||||
stop_event=threading.Event(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
dispatcher._dispatcher_loop()
|
dispatcher._dispatcher_loop()
|
||||||
@ -184,7 +181,6 @@ def test_dispatcher_drain_event_queue():
|
|||||||
event_queue=event_queue,
|
event_queue=event_queue,
|
||||||
event_handler=event_handler,
|
event_handler=event_handler,
|
||||||
execution_coordinator=coordinator,
|
execution_coordinator=coordinator,
|
||||||
stop_event=threading.Event(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
dispatcher._dispatcher_loop()
|
dispatcher._dispatcher_loop()
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import queue
|
import queue
|
||||||
import threading
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
@ -65,7 +64,6 @@ def test_dispatcher_drains_events_when_paused() -> None:
|
|||||||
event_handler=handler,
|
event_handler=handler,
|
||||||
execution_coordinator=coordinator,
|
execution_coordinator=coordinator,
|
||||||
event_emitter=None,
|
event_emitter=None,
|
||||||
stop_event=threading.Event(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
dispatcher._dispatcher_loop()
|
dispatcher._dispatcher_loop()
|
||||||
|
|||||||
@ -1,550 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for stop_event functionality in GraphEngine.
|
|
||||||
|
|
||||||
Tests the unified stop_event management by GraphEngine and its propagation
|
|
||||||
to WorkerPool, Worker, Dispatcher, and Nodes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
|
||||||
from core.workflow.graph_events import (
|
|
||||||
GraphRunStartedEvent,
|
|
||||||
GraphRunSucceededEvent,
|
|
||||||
NodeRunStartedEvent,
|
|
||||||
)
|
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
|
||||||
from core.workflow.nodes.start.start_node import StartNode
|
|
||||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopEventPropagation:
|
|
||||||
"""Test suite for stop_event propagation through GraphEngine components."""
|
|
||||||
|
|
||||||
def test_graph_engine_creates_stop_event(self):
|
|
||||||
"""Test that GraphEngine creates a stop_event on initialization."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify stop_event was created
|
|
||||||
assert engine._stop_event is not None
|
|
||||||
assert isinstance(engine._stop_event, threading.Event)
|
|
||||||
|
|
||||||
# Verify it was set in graph_runtime_state
|
|
||||||
assert runtime_state.stop_event is not None
|
|
||||||
assert runtime_state.stop_event is engine._stop_event
|
|
||||||
|
|
||||||
def test_stop_event_cleared_on_start(self):
|
|
||||||
"""Test that stop_event is cleared when execution starts."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
mock_graph.root_node.id = "start" # Set proper id
|
|
||||||
|
|
||||||
start_node = StartNode(
|
|
||||||
id="start",
|
|
||||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
mock_graph.nodes["start"] = start_node
|
|
||||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
|
||||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the stop_event before running
|
|
||||||
engine._stop_event.set()
|
|
||||||
assert engine._stop_event.is_set()
|
|
||||||
|
|
||||||
# Run the engine (should clear the stop_event)
|
|
||||||
events = list(engine.run())
|
|
||||||
|
|
||||||
# After running, stop_event should be set again (by _stop_execution)
|
|
||||||
# But during start it was cleared
|
|
||||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
|
||||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
|
||||||
|
|
||||||
def test_stop_event_set_on_stop(self):
|
|
||||||
"""Test that stop_event is set when execution stops."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
mock_graph.root_node.id = "start" # Set proper id
|
|
||||||
|
|
||||||
start_node = StartNode(
|
|
||||||
id="start",
|
|
||||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
mock_graph.nodes["start"] = start_node
|
|
||||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
|
||||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initially not set
|
|
||||||
assert not engine._stop_event.is_set()
|
|
||||||
|
|
||||||
# Run the engine
|
|
||||||
list(engine.run())
|
|
||||||
|
|
||||||
# After execution completes, stop_event should be set
|
|
||||||
assert engine._stop_event.is_set()
|
|
||||||
|
|
||||||
def test_stop_event_passed_to_worker_pool(self):
|
|
||||||
"""Test that stop_event is passed to WorkerPool."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify WorkerPool has the stop_event
|
|
||||||
assert engine._worker_pool._stop_event is not None
|
|
||||||
assert engine._worker_pool._stop_event is engine._stop_event
|
|
||||||
|
|
||||||
def test_stop_event_passed_to_dispatcher(self):
|
|
||||||
"""Test that stop_event is passed to Dispatcher."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify Dispatcher has the stop_event
|
|
||||||
assert engine._dispatcher._stop_event is not None
|
|
||||||
assert engine._dispatcher._stop_event is engine._stop_event
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeStopCheck:
|
|
||||||
"""Test suite for Node._should_stop() functionality."""
|
|
||||||
|
|
||||||
def test_node_should_stop_checks_runtime_state(self):
|
|
||||||
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
|
|
||||||
answer_node = AnswerNode(
|
|
||||||
id="answer",
|
|
||||||
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initially stop_event is not set
|
|
||||||
assert not answer_node._should_stop()
|
|
||||||
|
|
||||||
# Set the stop_event
|
|
||||||
runtime_state.stop_event.set()
|
|
||||||
|
|
||||||
# Now _should_stop should return True
|
|
||||||
assert answer_node._should_stop()
|
|
||||||
|
|
||||||
def test_node_run_checks_stop_event_between_yields(self):
|
|
||||||
"""Test that Node.run() checks stop_event between yielding events."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
|
|
||||||
# Create a simple node
|
|
||||||
answer_node = AnswerNode(
|
|
||||||
id="answer",
|
|
||||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set stop_event BEFORE running the node
|
|
||||||
runtime_state.stop_event.set()
|
|
||||||
|
|
||||||
# Run the node - should yield start event then detect stop
|
|
||||||
# The node should check stop_event before processing
|
|
||||||
assert answer_node._should_stop(), "stop_event should be set"
|
|
||||||
|
|
||||||
# Run and collect events
|
|
||||||
events = list(answer_node.run())
|
|
||||||
|
|
||||||
# Since stop_event is set at the start, we should get:
|
|
||||||
# 1. NodeRunStartedEvent (always yielded first)
|
|
||||||
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
|
|
||||||
assert len(events) >= 2
|
|
||||||
assert isinstance(events[0], NodeRunStartedEvent)
|
|
||||||
|
|
||||||
# Note: AnswerNode is very simple and might complete before stop check
|
|
||||||
# The important thing is that _should_stop() returns True when stop_event is set
|
|
||||||
assert answer_node._should_stop()
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopEventIntegration:
|
|
||||||
"""Integration tests for stop_event in workflow execution."""
|
|
||||||
|
|
||||||
def test_simple_workflow_respects_stop_event(self):
|
|
||||||
"""Test that a simple workflow respects stop_event."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
mock_graph.root_node.id = "start"
|
|
||||||
|
|
||||||
# Create start and answer nodes
|
|
||||||
start_node = StartNode(
|
|
||||||
id="start",
|
|
||||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
answer_node = AnswerNode(
|
|
||||||
id="answer",
|
|
||||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_graph.nodes["start"] = start_node
|
|
||||||
mock_graph.nodes["answer"] = answer_node
|
|
||||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
|
||||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set stop_event before running
|
|
||||||
runtime_state.stop_event.set()
|
|
||||||
|
|
||||||
# Run the engine
|
|
||||||
events = list(engine.run())
|
|
||||||
|
|
||||||
# Should get started event but not succeeded (due to stop)
|
|
||||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
|
||||||
# The workflow should still complete (start node runs quickly)
|
|
||||||
# but answer node might be cancelled depending on timing
|
|
||||||
|
|
||||||
def test_stop_event_with_concurrent_nodes(self):
|
|
||||||
"""Test stop_event behavior with multiple concurrent nodes."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
# Create multiple nodes
|
|
||||||
for i in range(3):
|
|
||||||
answer_node = AnswerNode(
|
|
||||||
id=f"answer_{i}",
|
|
||||||
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
mock_graph.nodes[f"answer_{i}"] = answer_node
|
|
||||||
|
|
||||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
|
||||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# All nodes should share the same stop_event
|
|
||||||
for node in mock_graph.nodes.values():
|
|
||||||
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
|
|
||||||
assert node.graph_runtime_state.stop_event is engine._stop_event
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopEventTimeoutBehavior:
|
|
||||||
"""Test stop_event behavior with join timeouts."""
|
|
||||||
|
|
||||||
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
|
|
||||||
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
|
|
||||||
"""Test that Dispatcher uses 2s timeout instead of 10s."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
dispatcher = engine._dispatcher
|
|
||||||
dispatcher.start() # This will create and start the mocked thread
|
|
||||||
|
|
||||||
mock_thread_instance = mock_thread_cls.return_value
|
|
||||||
mock_thread_instance.is_alive.return_value = True
|
|
||||||
|
|
||||||
dispatcher.stop()
|
|
||||||
|
|
||||||
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
|
|
||||||
|
|
||||||
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
|
|
||||||
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
|
|
||||||
"""Test that WorkerPool uses 2s timeout instead of 10s."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
worker_pool = engine._worker_pool
|
|
||||||
worker_pool.start(initial_count=1) # Start with one worker
|
|
||||||
|
|
||||||
mock_worker_instance = mock_worker_cls.return_value
|
|
||||||
mock_worker_instance.is_alive.return_value = True
|
|
||||||
|
|
||||||
worker_pool.stop()
|
|
||||||
|
|
||||||
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopEventResumeBehavior:
|
|
||||||
"""Test stop_event behavior during workflow resume."""
|
|
||||||
|
|
||||||
def test_stop_event_cleared_on_resume(self):
|
|
||||||
"""Test that stop_event is cleared when resuming a paused workflow."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
mock_graph.root_node.id = "start" # Set proper id
|
|
||||||
|
|
||||||
start_node = StartNode(
|
|
||||||
id="start",
|
|
||||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph_config={},
|
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
)
|
|
||||||
mock_graph.nodes["start"] = start_node
|
|
||||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
|
||||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simulate a previous execution that set stop_event
|
|
||||||
engine._stop_event.set()
|
|
||||||
assert engine._stop_event.is_set()
|
|
||||||
|
|
||||||
# Run the engine (should clear stop_event in _start_execution)
|
|
||||||
events = list(engine.run())
|
|
||||||
|
|
||||||
# Execution should complete successfully
|
|
||||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
|
||||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
|
||||||
|
|
||||||
|
|
||||||
class TestWorkerStopBehavior:
|
|
||||||
"""Test Worker behavior with shared stop_event."""
|
|
||||||
|
|
||||||
def test_worker_uses_shared_stop_event(self):
|
|
||||||
"""Test that Worker uses shared stop_event from GraphEngine."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
mock_graph = MagicMock(spec=Graph)
|
|
||||||
mock_graph.nodes = {}
|
|
||||||
mock_graph.edges = {}
|
|
||||||
mock_graph.root_node = MagicMock()
|
|
||||||
|
|
||||||
engine = GraphEngine(
|
|
||||||
workflow_id="test_workflow",
|
|
||||||
graph=mock_graph,
|
|
||||||
graph_runtime_state=runtime_state,
|
|
||||||
command_channel=InMemoryChannel(),
|
|
||||||
config=GraphEngineConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the worker pool and check workers
|
|
||||||
worker_pool = engine._worker_pool
|
|
||||||
|
|
||||||
# Start the worker pool to create workers
|
|
||||||
worker_pool.start()
|
|
||||||
|
|
||||||
# Check that at least one worker was created
|
|
||||||
assert len(worker_pool._workers) > 0
|
|
||||||
|
|
||||||
# Verify workers use the shared stop_event
|
|
||||||
for worker in worker_pool._workers:
|
|
||||||
assert worker._stop_event is engine._stop_event
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
worker_pool.stop()
|
|
||||||
|
|
||||||
def test_worker_stop_is_noop(self):
|
|
||||||
"""Test that Worker.stop() is now a no-op."""
|
|
||||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
|
||||||
|
|
||||||
# Create a mock worker
|
|
||||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
|
||||||
from core.workflow.graph_engine.worker import Worker
|
|
||||||
|
|
||||||
ready_queue = InMemoryReadyQueue()
|
|
||||||
event_queue = MagicMock()
|
|
||||||
|
|
||||||
# Create a proper mock graph with real dict
|
|
||||||
mock_graph = Mock(spec=Graph)
|
|
||||||
mock_graph.nodes = {} # Use real dict
|
|
||||||
|
|
||||||
stop_event = threading.Event()
|
|
||||||
|
|
||||||
worker = Worker(
|
|
||||||
ready_queue=ready_queue,
|
|
||||||
event_queue=event_queue,
|
|
||||||
graph=mock_graph,
|
|
||||||
layers=[],
|
|
||||||
stop_event=stop_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calling stop() should do nothing (no-op)
|
|
||||||
# and should NOT set the stop_event
|
|
||||||
worker.stop()
|
|
||||||
assert not stop_event.is_set()
|
|
||||||
Loading…
Reference in New Issue
Block a user