Fix: check external commands after node completion (#26891)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
-LAN- 2025-10-15 16:47:43 +08:00 committed by GitHub
parent 3474c179e6
commit 1d8cca4fa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 260 additions and 43 deletions

View File

@ -41,6 +41,7 @@ class RedisChannel:
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
self._pending_key = f"{channel_key}:pending"
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
@ -49,6 +50,9 @@ class RedisChannel:
Returns:
List of pending commands (drains the Redis list)
"""
if not self._has_pending_commands():
return []
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
@ -85,6 +89,7 @@ class RedisChannel:
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.set(self._pending_key, "1", ex=self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
@ -112,3 +117,17 @@ class RedisChannel:
except (ValueError, TypeError):
return None
def _has_pending_commands(self) -> bool:
    """
    Read the pending marker key and clear it on the same pipeline.

    Used as a cheap gate so callers can skip reading the command list
    when nothing has been flagged as pending.

    Returns:
        True when the marker key held a value, False when it was absent.
    """
    marker_key = self._pending_key
    with self._redis.pipeline() as marker_pipe:
        # Queue GET then DELETE; execute() yields one result per queued call.
        marker_pipe.get(marker_key)
        marker_pipe.delete(marker_key)
        marker_value, _deleted = marker_pipe.execute()
    # Redis GET yields None for a missing key, so any other value means "pending".
    return marker_value is not None

View File

@ -8,7 +8,12 @@ import threading
import time
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunSucceededEvent,
)
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
@ -72,13 +77,16 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
# Event types that _should_check_commands() treats as a node completion:
# an event that is an instance of any of these makes the dispatcher loop
# call check_commands() right after the event has been dispatched.
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self._execution_coordinator.check_commands()
# Check for scaling
self._execution_coordinator.check_scaling()
@ -87,6 +95,8 @@ class Dispatcher:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
if self._should_check_commands(event):
self._execution_coordinator.check_commands()
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
@ -102,3 +112,7 @@ class Dispatcher:
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
    """Tell whether ``event`` is one of the configured command-trigger event types.

    Returns:
        True if ``event`` is an instance of any type in ``_COMMAND_TRIGGER_EVENTS``.
    """
    trigger_types = self._COMMAND_TRIGGER_EVENTS
    return isinstance(event, trigger_types)

View File

@ -35,11 +35,15 @@ class TestRedisChannel:
"""Test sending a command to Redis."""
mock_redis = MagicMock()
mock_pipe = MagicMock()
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
context = MagicMock()
context.__enter__.return_value = mock_pipe
context.__exit__.return_value = None
mock_redis.pipeline.return_value = context
channel = RedisChannel(mock_redis, "test:key", 3600)
pending_key = "test:key:pending"
# Create a test command
command = GraphEngineCommand(command_type=CommandType.ABORT)
@ -55,6 +59,7 @@ class TestRedisChannel:
# Verify expire was set
mock_pipe.expire.assert_called_once_with("test:key", 3600)
mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600)
# Verify execute was called
mock_pipe.execute.assert_called_once()
@ -62,33 +67,48 @@ class TestRedisChannel:
# NOTE(review): this span appears to interleave the removed and added sides of
# a diff hunk. The mock_pipe lines look like the pre-change version of the test
# and the pending_pipe / fetch_pipe lines like the post-change version — confirm
# against the real test file before treating this as a single runnable test.
def test_fetch_commands_empty(self):
"""Test fetching commands when Redis list is empty."""
mock_redis = MagicMock()
mock_pipe = MagicMock()
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
# Only one pipeline context is queued: a second pipeline() call would exhaust
# the side_effect list, so this setup expects the pending check alone to run.
mock_redis.pipeline.side_effect = [pending_context]
# Simulate empty list
mock_pipe.execute.return_value = [[], 1] # Empty list, delete successful
# No pending marker
pending_pipe.execute.return_value = [None, 0]
mock_redis.llen.return_value = 0
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
assert commands == []
# NOTE(review): the two mock_pipe assertions below contradict the
# fetch_pipe "not called" assertions further down — presumably they are the
# removed lines of the hunk; verify.
mock_pipe.lrange.assert_called_once_with("test:key", 0, -1)
mock_pipe.delete.assert_called_once_with("test:key")
mock_redis.pipeline.assert_called_once()
fetch_pipe.lrange.assert_not_called()
fetch_pipe.delete.assert_not_called()
def test_fetch_commands_with_abort_command(self):
"""Test fetching abort commands from Redis."""
mock_redis = MagicMock()
mock_pipe = MagicMock()
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Create abort command data
abort_command = AbortCommand()
command_json = json.dumps(abort_command.model_dump())
# Simulate Redis returning one command
mock_pipe.execute.return_value = [[command_json.encode()], 1]
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
@ -100,9 +120,15 @@ class TestRedisChannel:
def test_fetch_commands_multiple(self):
"""Test fetching multiple commands from Redis."""
mock_redis = MagicMock()
mock_pipe = MagicMock()
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Create multiple commands
command1 = GraphEngineCommand(command_type=CommandType.ABORT)
@ -112,7 +138,8 @@ class TestRedisChannel:
command2_json = json.dumps(command2.model_dump())
# Simulate Redis returning multiple commands
mock_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1]
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
@ -124,9 +151,15 @@ class TestRedisChannel:
def test_fetch_commands_skips_invalid_json(self):
"""Test that invalid JSON commands are skipped."""
mock_redis = MagicMock()
mock_pipe = MagicMock()
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Mix valid and invalid JSON
valid_command = AbortCommand()
@ -134,7 +167,8 @@ class TestRedisChannel:
invalid_json = b"invalid json {"
# Simulate Redis returning mixed valid/invalid commands
mock_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1]
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
@ -187,13 +221,20 @@ class TestRedisChannel:
def test_atomic_fetch_and_clear(self):
"""Test that fetch_commands atomically fetches and clears the list."""
mock_redis = MagicMock()
mock_pipe = MagicMock()
mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe)
mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
command = AbortCommand()
command_json = json.dumps(command.model_dump())
mock_pipe.execute.return_value = [[command_json.encode()], 1]
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
@ -202,7 +243,29 @@ class TestRedisChannel:
assert len(commands) == 1
# Verify both lrange and delete were called in the pipeline
assert mock_pipe.lrange.call_count == 1
assert mock_pipe.delete.call_count == 1
mock_pipe.lrange.assert_called_with("test:key", 0, -1)
mock_pipe.delete.assert_called_with("test:key")
assert fetch_pipe.lrange.call_count == 1
assert fetch_pipe.delete.call_count == 1
fetch_pipe.lrange.assert_called_with("test:key", 0, -1)
fetch_pipe.delete.assert_called_with("test:key")
def test_fetch_commands_without_pending_marker_returns_empty(self):
    """fetch_commands should return [] after one pipeline call when no marker is set."""
    # Pipeline used for the pending-marker check: GET yields None, DELETE yields 0.
    marker_pipe = MagicMock()
    marker_pipe.execute.return_value = [None, 0]
    marker_ctx = MagicMock()
    marker_ctx.__enter__.return_value = marker_pipe
    marker_ctx.__exit__.return_value = None

    # A second pipeline is made available, but the assertions expect it to stay unused.
    list_pipe = MagicMock()
    list_ctx = MagicMock()
    list_ctx.__enter__.return_value = list_pipe
    list_ctx.__exit__.return_value = None

    redis_mock = MagicMock()
    redis_mock.pipeline.side_effect = [marker_ctx, list_ctx]

    fetched = RedisChannel(redis_mock, "test:key").fetch_commands()

    assert fetched == []
    redis_mock.llen.assert_not_called()
    assert redis_mock.pipeline.call_count == 1

View File

@ -0,0 +1,104 @@
"""Tests for dispatcher command checking behavior."""
from __future__ import annotations

import queue
from datetime import datetime, timezone

from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
from core.workflow.graph_events import NodeRunStartedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
class _StubExecutionCoordinator:
    """Execution-coordinator stand-in that counts how often it is polled."""

    def __init__(self) -> None:
        # Value reported by is_execution_complete(); flipped by set_execution_complete().
        self._execution_complete = False
        # Counters bumped by check_commands() / check_scaling().
        self.command_checks = 0
        self.scaling_checks = 0
        # Flags recorded by mark_complete() / mark_failed().
        self.mark_complete_called = False
        self.failed = False

    def check_commands(self) -> None:
        """Count one command check."""
        self.command_checks = self.command_checks + 1

    def check_scaling(self) -> None:
        """Count one scaling check."""
        self.scaling_checks = self.scaling_checks + 1

    def is_execution_complete(self) -> bool:
        """Report the completion flag."""
        return self._execution_complete

    def mark_complete(self) -> None:
        """Remember that mark_complete() was invoked."""
        self.mark_complete_called = True

    def mark_failed(self, error: Exception) -> None:  # pragma: no cover - defensive, not triggered in tests
        """Remember that a failure was reported; the error itself is not stored."""
        self.failed = True

    def set_execution_complete(self) -> None:
        """Make is_execution_complete() return True from now on."""
        self._execution_complete = True
class _StubEventHandler:
    """Event handler stub: records each event, then flags execution as complete."""

    def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
        self.events = []
        self._coordinator = coordinator

    def dispatch(self, event) -> None:
        """Store ``event`` and mark the coordinator's execution as complete."""
        self.events.extend((event,))
        self._coordinator.set_execution_complete()
def _run_dispatcher_for_event(event) -> int:
    """Push one event through the dispatcher loop and return the number of command checks.

    The loop is invoked directly via ``_dispatcher_loop()`` with stub
    collaborators; the stub handler marks execution complete after its
    first dispatched event.
    """
    stub_coordinator = _StubExecutionCoordinator()

    pending_events: queue.Queue = queue.Queue()
    pending_events.put(event)

    Dispatcher(
        event_queue=pending_events,
        event_handler=_StubEventHandler(stub_coordinator),
        event_collector=EventManager(),
        execution_coordinator=stub_coordinator,
    )._dispatcher_loop()

    return stub_coordinator.command_checks
def _make_started_event() -> NodeRunStartedEvent:
    """Build a NodeRunStartedEvent fixture for a CODE node.

    Returns:
        A started event whose ``start_at`` is the current UTC time as a naive datetime.
    """
    # datetime.utcnow() is deprecated since Python 3.12; derive the same naive
    # UTC timestamp from an aware now() so the value passed on is unchanged.
    started_at = datetime.now(timezone.utc).replace(tzinfo=None)
    return NodeRunStartedEvent(
        id="start-event",
        node_id="node-1",
        node_type=NodeType.CODE,
        node_title="Test Node",
        start_at=started_at,
    )
def _make_succeeded_event() -> NodeRunSucceededEvent:
    """Build a NodeRunSucceededEvent fixture for a CODE node.

    Returns:
        A succeeded event carrying a SUCCEEDED NodeRunResult and a ``start_at``
        set to the current UTC time as a naive datetime.
    """
    # datetime.utcnow() is deprecated since Python 3.12; derive the same naive
    # UTC timestamp from an aware now() so the value passed on is unchanged.
    started_at = datetime.now(timezone.utc).replace(tzinfo=None)
    return NodeRunSucceededEvent(
        id="success-event",
        node_id="node-1",
        node_type=NodeType.CODE,
        node_title="Test Node",
        start_at=started_at,
        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
    )
def test_dispatcher_checks_commands_after_node_completion() -> None:
    """A started event triggers no command check; a succeeded event triggers exactly one."""
    checks_after_start = _run_dispatcher_for_event(_make_started_event())
    checks_after_success = _run_dispatcher_for_event(_make_succeeded_event())

    assert (checks_after_start, checks_after_success) == (0, 1)

View File

@ -132,15 +132,22 @@ class TestRedisStopIntegration:
"""Test RedisChannel correctly fetches and deserializes commands."""
# Setup
mock_redis = MagicMock()
mock_pipeline = MagicMock()
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Mock command data
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
# Mock pipeline execute to return commands
mock_pipeline.execute.return_value = [
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [
[abort_command_json.encode()], # lrange result
True, # delete result
]
@ -158,19 +165,29 @@ class TestRedisStopIntegration:
assert commands[0].reason == "Test abort"
# Verify Redis operations
mock_pipeline.lrange.assert_called_once_with(channel_key, 0, -1)
mock_pipeline.delete.assert_called_once_with(channel_key)
pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")
pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending")
fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1)
fetch_pipe.delete.assert_called_once_with(channel_key)
assert mock_redis.pipeline.call_count == 2
def test_redis_channel_fetch_commands_handles_invalid_json(self):
"""Test RedisChannel gracefully handles invalid JSON in commands."""
# Setup
mock_redis = MagicMock()
mock_pipeline = MagicMock()
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Mock invalid command data
mock_pipeline.execute.return_value = [
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [
[b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result
True, # delete result
]