refactor: consume events after pause/abort and improve API clarity (#28328)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Authored by -LAN- on 2025-11-18 19:04:11 +08:00, committed by GitHub
parent 68526c09fc
commit 6efdc94661
7 changed files with 261 additions and 183 deletions
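
The heart of the change: once a run is paused, aborted, or complete, the dispatcher empties its event queue before signalling completion instead of returning with node events still enqueued. A minimal standalone sketch of that non-blocking drain pattern (the drain_events helper and dispatch callable are illustrative, not code from this PR):

import queue
from typing import Callable


def drain_events(event_queue: queue.Queue, dispatch: Callable[[object], None]) -> None:
    # Consume whatever is already enqueued without waiting for new events.
    while True:
        try:
            event = event_queue.get(block=False)
        except queue.Empty:
            break
        dispatch(event)
        event_queue.task_done()


q: queue.Queue = queue.Queue()
q.put("node-started")
q.put("node-succeeded")
drain_events(q, print)
assert q.empty()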

View File

@@ -192,7 +192,6 @@ class GraphEngine:
         self._dispatcher = Dispatcher(
             event_queue=self._event_queue,
             event_handler=self._event_handler_registry,
-            event_collector=self._event_manager,
             execution_coordinator=self._execution_coordinator,
             event_emitter=self._event_manager,
         )

View File

@@ -43,7 +43,6 @@ class Dispatcher:
         self,
         event_queue: queue.Queue[GraphNodeEventBase],
         event_handler: "EventHandler",
-        event_collector: EventManager,
         execution_coordinator: ExecutionCoordinator,
         event_emitter: EventManager | None = None,
     ) -> None:
@@ -53,13 +52,11 @@
         Args:
             event_queue: Queue of events from workers
             event_handler: Event handler registry for processing events
-            event_collector: Event manager for collecting unhandled events
             execution_coordinator: Coordinator for execution flow
             event_emitter: Optional event manager to signal completion
         """
         self._event_queue = event_queue
         self._event_handler = event_handler
-        self._event_collector = event_collector
         self._execution_coordinator = execution_coordinator
         self._event_emitter = event_emitter
@@ -86,37 +83,31 @@
     def _dispatcher_loop(self) -> None:
        """Main dispatcher loop."""
         try:
+            self._process_commands()
             while not self._stop_event.is_set():
-                commands_checked = False
-                should_check_commands = False
-                should_break = False
+                if (
+                    self._execution_coordinator.aborted
+                    or self._execution_coordinator.paused
+                    or self._execution_coordinator.execution_complete
+                ):
+                    break
-                if self._execution_coordinator.is_execution_complete():
-                    should_check_commands = True
-                    should_break = True
-                else:
-                    # Check for scaling
-                    self._execution_coordinator.check_scaling()
+                self._execution_coordinator.check_scaling()
+                try:
+                    event = self._event_queue.get(timeout=0.1)
+                    self._event_handler.dispatch(event)
+                    self._event_queue.task_done()
+                    self._process_commands(event)
+                except queue.Empty:
+                    time.sleep(0.1)
-                # Process events
-                try:
-                    event = self._event_queue.get(timeout=0.1)
-                    # Route to the event handler
-                    self._event_handler.dispatch(event)
-                    should_check_commands = self._should_check_commands(event)
-                    self._event_queue.task_done()
-                except queue.Empty:
-                    # Process commands even when no new events arrive so abort requests are not missed
-                    should_check_commands = True
-                    time.sleep(0.1)
-                if should_check_commands and not commands_checked:
-                    self._execution_coordinator.check_commands()
-                    commands_checked = True
-                if should_break:
-                    if not commands_checked:
-                        self._execution_coordinator.check_commands()
+            self._process_commands()
+            while True:
+                try:
+                    event = self._event_queue.get(block=False)
+                    self._event_handler.dispatch(event)
+                    self._event_queue.task_done()
+                except queue.Empty:
+                    break
         except Exception as e:
@@ -129,6 +120,6 @@
             if self._event_emitter:
                 self._event_emitter.mark_complete()

-    def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
-        """Return True if the event represents a node completion."""
-        return isinstance(event, self._COMMAND_TRIGGER_EVENTS)
+    def _process_commands(self, event: GraphNodeEventBase | None = None):
+        if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
+            self._execution_coordinator.process_commands()
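
The renamed helper keeps command handling cheap on the hot path: process_commands() on the coordinator runs unconditionally when no event is passed (start-up and after the loop exits) and otherwise only for events in _COMMAND_TRIGGER_EVENTS, whose members are not shown in this diff. A rough sketch of the same gating, assuming only completion-style events trigger a poll (all names below are placeholders, not the real constants or classes):

from dataclasses import dataclass


@dataclass
class NodeSucceeded:  # stand-in for a completion-style event
    node_id: str


@dataclass
class NodeStarted:  # stand-in for a non-triggering event
    node_id: str


COMMAND_TRIGGER_EVENTS = (NodeSucceeded,)  # assumed membership, for illustration only


class CommandPoller:
    def __init__(self) -> None:
        self.polls = 0

    def process_commands(self) -> None:
        self.polls += 1


def process_commands_if_relevant(poller: CommandPoller, event: object | None = None) -> None:
    # Mirrors the gating in Dispatcher._process_commands: no event means an
    # unconditional poll; otherwise poll only for trigger events.
    if event is None or isinstance(event, COMMAND_TRIGGER_EVENTS):
        poller.process_commands()


poller = CommandPoller()
process_commands_if_relevant(poller)                      # start-up / idle: polls
process_commands_if_relevant(poller, NodeStarted("n1"))   # not a trigger: skipped
process_commands_if_relevant(poller, NodeSucceeded("n1")) # completion: polls
assert poller.polls == 2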

View File

@@ -40,7 +40,7 @@ class ExecutionCoordinator:
         self._command_processor = command_processor
         self._worker_pool = worker_pool

-    def check_commands(self) -> None:
+    def process_commands(self) -> None:
         """Process any pending commands."""
         self._command_processor.process_commands()
@@ -48,24 +48,16 @@
         """Check and perform worker scaling if needed."""
         self._worker_pool.check_and_scale()

-    def is_execution_complete(self) -> bool:
-        """
-        Check if execution is complete.
-        Returns:
-            True if execution is complete
-        """
-        # Treat paused, aborted, or failed executions as terminal states
-        if self._graph_execution.is_paused:
-            return True
-        if self._graph_execution.aborted or self._graph_execution.has_error:
-            return True
+    @property
+    def execution_complete(self):
         return self._state_manager.is_execution_complete()

     @property
-    def is_paused(self) -> bool:
+    def aborted(self):
+        return self._graph_execution.aborted or self._graph_execution.has_error
+
+    @property
+    def paused(self) -> bool:
         """Expose whether the underlying graph execution is paused."""
         return self._graph_execution.is_paused
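
With ExecutionCoordinator exposing plain properties (execution_complete, aborted, paused) instead of the old is_execution_complete()/is_paused methods, the dispatcher's stop condition reads as a simple state check. A small self-contained sketch of how a loop guard reads against this surface (the stub class is illustrative, not the real coordinator):

from dataclasses import dataclass


@dataclass
class CoordinatorStub:
    # Illustrative stand-in mirroring the new property-style surface.
    aborted: bool = False
    paused: bool = False
    execution_complete: bool = False


def should_stop(coordinator: CoordinatorStub) -> bool:
    # Same shape as the guard at the top of the new dispatcher loop.
    return coordinator.aborted or coordinator.paused or coordinator.execution_complete


assert should_stop(CoordinatorStub()) is False
assert should_stop(CoordinatorStub(paused=True)) is True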

View File

@@ -0,0 +1,189 @@
"""Tests for dispatcher command checking behavior."""

from __future__ import annotations

import queue
from datetime import datetime
from unittest import mock

from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator
from core.workflow.graph_events import (
    GraphNodeEventBase,
    NodeRunPauseRequestedEvent,
    NodeRunStartedEvent,
    NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult


def test_dispatcher_should_consume_remains_events_after_pause():
    event_queue = queue.Queue()
    event_queue.put(
        GraphNodeEventBase(
            id="test",
            node_id="test",
            node_type=NodeType.START,
        )
    )
    event_handler = mock.Mock(spec=EventHandler)
    execution_coordinator = mock.Mock(spec=ExecutionCoordinator)
    execution_coordinator.paused.return_value = True
    dispatcher = Dispatcher(
        event_queue=event_queue,
        event_handler=event_handler,
        execution_coordinator=execution_coordinator,
    )
    dispatcher._dispatcher_loop()
    assert event_queue.empty()


class _StubExecutionCoordinator:
    """Stub execution coordinator that tracks command checks."""

    def __init__(self) -> None:
        self.command_checks = 0
        self.scaling_checks = 0
        self.execution_complete = False
        self.failed = False
        self._paused = False

    def process_commands(self) -> None:
        self.command_checks += 1

    def check_scaling(self) -> None:
        self.scaling_checks += 1

    @property
    def paused(self) -> bool:
        return self._paused

    @property
    def aborted(self) -> bool:
        return False

    def mark_complete(self) -> None:
        self.execution_complete = True

    def mark_failed(self, error: Exception) -> None:  # pragma: no cover - defensive, not triggered in tests
        self.failed = True


class _StubEventHandler:
    """Minimal event handler that marks execution complete after handling an event."""

    def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
        self._coordinator = coordinator
        self.events = []

    def dispatch(self, event) -> None:
        self.events.append(event)
        self._coordinator.mark_complete()


def _run_dispatcher_for_event(event) -> int:
    """Run the dispatcher loop for a single event and return command check count."""
    event_queue: queue.Queue = queue.Queue()
    event_queue.put(event)
    coordinator = _StubExecutionCoordinator()
    event_handler = _StubEventHandler(coordinator)
    dispatcher = Dispatcher(
        event_queue=event_queue,
        event_handler=event_handler,
        execution_coordinator=coordinator,
    )
    dispatcher._dispatcher_loop()
    return coordinator.command_checks


def _make_started_event() -> NodeRunStartedEvent:
    return NodeRunStartedEvent(
        id="start-event",
        node_id="node-1",
        node_type=NodeType.CODE,
        node_title="Test Node",
        start_at=datetime.utcnow(),
    )


def _make_succeeded_event() -> NodeRunSucceededEvent:
    return NodeRunSucceededEvent(
        id="success-event",
        node_id="node-1",
        node_type=NodeType.CODE,
        node_title="Test Node",
        start_at=datetime.utcnow(),
        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
    )


def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
    """Dispatcher polls commands when idle and after completion events."""
    started_checks = _run_dispatcher_for_event(_make_started_event())
    succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())

    assert started_checks == 2
    assert succeeded_checks == 3


class _PauseStubEventHandler:
    """Minimal event handler that marks execution complete after handling an event."""

    def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
        self._coordinator = coordinator
        self.events = []

    def dispatch(self, event) -> None:
        self.events.append(event)
        if isinstance(event, NodeRunPauseRequestedEvent):
            self._coordinator.mark_complete()


def test_dispatcher_drain_event_queue():
    events = [
        NodeRunStartedEvent(
            id="start-event",
            node_id="node-1",
            node_type=NodeType.CODE,
            node_title="Code",
            start_at=datetime.utcnow(),
        ),
        NodeRunPauseRequestedEvent(
            id="pause-event",
            node_id="node-1",
            node_type=NodeType.CODE,
            reason=SchedulingPause(message="test pause"),
        ),
        NodeRunSucceededEvent(
            id="success-event",
            node_id="node-1",
            node_type=NodeType.CODE,
            start_at=datetime.utcnow(),
            node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
        ),
    ]
    event_queue: queue.Queue = queue.Queue()
    for e in events:
        event_queue.put(e)
    coordinator = _StubExecutionCoordinator()
    event_handler = _PauseStubEventHandler(coordinator)
    dispatcher = Dispatcher(
        event_queue=event_queue,
        event_handler=event_handler,
        execution_coordinator=coordinator,
    )
    dispatcher._dispatcher_loop()
    # ensure all events are drained.
    assert event_queue.empty()
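
The asserted counts in test_dispatcher_checks_commands_during_idle_and_on_completion follow from where the new loop polls commands, assuming NodeRunSucceededEvent is in _COMMAND_TRIGGER_EVENTS and NodeRunStartedEvent is not (consistent with the removed docstring about node completion): one unconditional poll when the loop starts, one gated poll per dispatched trigger event, and one more unconditional poll before the dispatcher finishes.

# Assumed breakdown of coordinator.command_checks (illustrative, not source):
started_checks = 1 + 0 + 1    # start-up poll + no trigger poll + final poll
succeeded_checks = 1 + 1 + 1  # start-up poll + trigger poll + final poll
assert (started_checks, succeeded_checks) == (2, 3)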

View File

@@ -3,13 +3,17 @@
 import time
 from unittest.mock import MagicMock

+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.graph_init_params import GraphInitParams
 from core.workflow.entities.pause_reason import SchedulingPause
 from core.workflow.graph import Graph
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
 from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
 from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
+from core.workflow.nodes.start.start_node import StartNode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
+from models.enums import UserFrom


 def test_abort_command():
@@ -26,11 +30,23 @@ def test_abort_command():
     mock_graph.root_node.id = "start"

     # Create mock nodes with required attributes - using shared runtime state
-    mock_start_node = MagicMock()
-    mock_start_node.state = None
-    mock_start_node.id = "start"
-    mock_start_node.graph_runtime_state = shared_runtime_state  # Use shared instance
-    mock_graph.nodes["start"] = mock_start_node
+    start_node = StartNode(
+        id="start",
+        config={"id": "start"},
+        graph_init_params=GraphInitParams(
+            tenant_id="test_tenant",
+            app_id="test_app",
+            workflow_id="test_workflow",
+            graph_config={},
+            user_id="test_user",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.DEBUGGER,
+            call_depth=0,
+        ),
+        graph_runtime_state=shared_runtime_state,
+    )
+    start_node.init_node_data({"title": "start", "variables": []})
+    mock_graph.nodes["start"] = start_node

     # Mock graph methods
     mock_graph.get_outgoing_edges = MagicMock(return_value=[])
@@ -124,11 +140,23 @@ def test_pause_command():
     mock_graph.root_node = MagicMock()
     mock_graph.root_node.id = "start"

-    mock_start_node = MagicMock()
-    mock_start_node.state = None
-    mock_start_node.id = "start"
-    mock_start_node.graph_runtime_state = shared_runtime_state
-    mock_graph.nodes["start"] = mock_start_node
+    start_node = StartNode(
+        id="start",
+        config={"id": "start"},
+        graph_init_params=GraphInitParams(
+            tenant_id="test_tenant",
+            app_id="test_app",
+            workflow_id="test_workflow",
+            graph_config={},
+            user_id="test_user",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.DEBUGGER,
+            call_depth=0,
+        ),
+        graph_runtime_state=shared_runtime_state,
+    )
+    start_node.init_node_data({"title": "start", "variables": []})
+    mock_graph.nodes["start"] = start_node

     mock_graph.get_outgoing_edges = MagicMock(return_value=[])
     mock_graph.get_incoming_edges = MagicMock(return_value=[])
@@ -153,5 +181,5 @@ def test_pause_command():
     assert pause_events[0].reason == SchedulingPause(message="User requested pause")

     graph_execution = engine.graph_runtime_state.graph_execution
-    assert graph_execution.is_paused
+    assert graph_execution.paused
     assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")

View File

@@ -1,109 +0,0 @@
"""Tests for dispatcher command checking behavior."""

from __future__ import annotations

import queue
from datetime import datetime

from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
from core.workflow.graph_events import NodeRunStartedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult


class _StubExecutionCoordinator:
    """Stub execution coordinator that tracks command checks."""

    def __init__(self) -> None:
        self.command_checks = 0
        self.scaling_checks = 0
        self._execution_complete = False
        self.mark_complete_called = False
        self.failed = False
        self._paused = False

    def check_commands(self) -> None:
        self.command_checks += 1

    def check_scaling(self) -> None:
        self.scaling_checks += 1

    @property
    def is_paused(self) -> bool:
        return self._paused

    def is_execution_complete(self) -> bool:
        return self._execution_complete

    def mark_complete(self) -> None:
        self.mark_complete_called = True

    def mark_failed(self, error: Exception) -> None:  # pragma: no cover - defensive, not triggered in tests
        self.failed = True

    def set_execution_complete(self) -> None:
        self._execution_complete = True


class _StubEventHandler:
    """Minimal event handler that marks execution complete after handling an event."""

    def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
        self._coordinator = coordinator
        self.events = []

    def dispatch(self, event) -> None:
        self.events.append(event)
        self._coordinator.set_execution_complete()


def _run_dispatcher_for_event(event) -> int:
    """Run the dispatcher loop for a single event and return command check count."""
    event_queue: queue.Queue = queue.Queue()
    event_queue.put(event)
    coordinator = _StubExecutionCoordinator()
    event_handler = _StubEventHandler(coordinator)
    event_manager = EventManager()
    dispatcher = Dispatcher(
        event_queue=event_queue,
        event_handler=event_handler,
        event_collector=event_manager,
        execution_coordinator=coordinator,
    )
    dispatcher._dispatcher_loop()
    return coordinator.command_checks


def _make_started_event() -> NodeRunStartedEvent:
    return NodeRunStartedEvent(
        id="start-event",
        node_id="node-1",
        node_type=NodeType.CODE,
        node_title="Test Node",
        start_at=datetime.utcnow(),
    )


def _make_succeeded_event() -> NodeRunSucceededEvent:
    return NodeRunSucceededEvent(
        id="success-event",
        node_id="node-1",
        node_type=NodeType.CODE,
        node_title="Test Node",
        start_at=datetime.utcnow(),
        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
    )


def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
    """Dispatcher polls commands when idle and after completion events."""
    started_checks = _run_dispatcher_for_event(_make_started_event())
    succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())

    assert started_checks == 1
    assert succeeded_checks == 2

View File

@@ -48,15 +48,3 @@ def test_handle_pause_noop_when_execution_running() -> None:
     worker_pool.stop.assert_not_called()
     state_manager.clear_executing.assert_not_called()
-
-
-def test_is_execution_complete_when_paused() -> None:
-    """Paused execution should be treated as complete."""
-    graph_execution = GraphExecution(workflow_id="workflow")
-    graph_execution.start()
-    graph_execution.pause("Awaiting input")
-
-    coordinator, state_manager, _worker_pool = _build_coordinator(graph_execution)
-    state_manager.is_execution_complete.return_value = False
-
-    assert coordinator.is_execution_complete()