diff --git a/api/.importlinter b/api/.importlinter index cd674dbf95..f615a2ea5f 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -56,8 +56,6 @@ ignore_imports = core.workflow.nodes.llm.llm_utils -> extensions.ext_database core.workflow.nodes.llm.node -> extensions.ext_database core.workflow.nodes.tool.tool_node -> extensions.ext_database - core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis - core.workflow.graph_engine.manager -> extensions.ext_redis # TODO(QuantumGhost): use DI to avoid depending on global DB. core.workflow.nodes.human_input.human_input_node -> extensions.ext_database @@ -105,7 +103,6 @@ forbidden_modules = core.variables ignore_imports = core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory - core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability core.workflow.nodes.agent.agent_node -> core.model_manager core.workflow.nodes.agent.agent_node -> core.provider_manager @@ -242,7 +239,6 @@ ignore_imports = core.workflow.variable_loader -> core.variables core.workflow.variable_loader -> core.variables.consts core.workflow.workflow_type_encoder -> core.variables - core.workflow.graph_engine.manager -> extensions.ext_redis core.workflow.nodes.agent.agent_node -> extensions.ext_database core.workflow.nodes.datasource.datasource_node -> extensions.ext_database core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index b05d28b686..a66e9543ff 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -33,6 +33,7 @@ from core.workflow.enums import NodeType from core.workflow.file.models import File from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db +from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields @@ -740,7 +741,7 @@ class WorkflowTaskStopApi(Resource): AppQueueManager.set_stop_flag_no_user_check(task_id) # New graph engine command channel mechanism - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 4ae12cecf5..f6f731df36 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -44,6 +44,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db +from extensions.ext_redis import redis_client from fields.app_fields import ( app_detail_fields_with_site, deleted_tool_fields, @@ -225,7 +226,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource): AppQueueManager.set_stop_flag_no_user_check(task_id) # New graph engine command channel mechanism - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index d679d0722d..b841bda323 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -23,6 +23,7 @@ from core.errors.error import ( ) from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager +from extensions.ext_redis import redis_client from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp @@ -100,6 +101,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): AppQueueManager.set_stop_flag_no_user_check(task_id) # New graph engine command channel mechanism - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 6088b142c2..2ce8f05f75 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -31,6 +31,7 @@ from core.model_runtime.errors.invoke import InvokeError from core.workflow.enums import WorkflowExecutionStatus from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db +from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from libs import helper from libs.helper import OptionalTimestampField, TimestampField @@ -280,7 +281,7 @@ class WorkflowTaskStopApi(Resource): AppQueueManager.set_stop_flag_no_user_check(task_id) # New graph engine command channel mechanism - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 95d8c6d5a5..a309ef3dad 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -24,6 +24,7 @@ from core.errors.error import ( ) from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager +from extensions.ext_redis import redis_client from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService @@ -121,6 +122,6 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag_no_user_check(task_id) # New graph engine command channel mechanism - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 0fccd4a0fd..77cf884c67 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -7,12 +7,28 @@ Each instance uses a unique key for its command queue. """ import json -from typing import TYPE_CHECKING, Any, final +from contextlib import AbstractContextManager +from typing import Any, Protocol, final from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand -if TYPE_CHECKING: - from extensions.ext_redis import RedisClientWrapper + +class RedisPipelineProtocol(Protocol): + """Minimal Redis pipeline contract used by the command channel.""" + + def lrange(self, name: str, start: int, end: int) -> Any: ... + def delete(self, *names: str) -> Any: ... + def execute(self) -> list[Any]: ... + def rpush(self, name: str, *values: str) -> Any: ... + def expire(self, name: str, time: int) -> Any: ... + def set(self, name: str, value: str, ex: int | None = None) -> Any: ... + def get(self, name: str) -> Any: ... + + +class RedisClientProtocol(Protocol): + """Redis client contract required by the command channel.""" + + def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... @final @@ -26,7 +42,7 @@ class RedisChannel: def __init__( self, - redis_client: "RedisClientWrapper", + redis_client: RedisClientProtocol, channel_key: str, command_ttl: int = 3600, ) -> None: diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index d2cfa755d9..36f1612af0 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -3,13 +3,14 @@ GraphEngine Manager for sending control commands via Redis channel. This module provides a simplified interface for controlling workflow executions using the new Redis command channel, without requiring user permission checks. +Callers must provide a Redis client dependency from outside the workflow package. """ import logging from collections.abc import Sequence from typing import final -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel +from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol from core.workflow.graph_engine.entities.commands import ( AbortCommand, GraphEngineCommand, @@ -17,7 +18,6 @@ from core.workflow.graph_engine.entities.commands import ( UpdateVariablesCommand, VariableUpdate, ) -from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -31,8 +31,12 @@ class GraphEngineManager: by sending commands through Redis channels, without user validation. """ - @staticmethod - def send_stop_command(task_id: str, reason: str | None = None) -> None: + _redis_client: RedisClientProtocol + + def __init__(self, redis_client: RedisClientProtocol) -> None: + self._redis_client = redis_client + + def send_stop_command(self, task_id: str, reason: str | None = None) -> None: """ Send a stop command to a running workflow. @@ -41,34 +45,31 @@ class GraphEngineManager: reason: Optional reason for stopping (defaults to "User requested stop") """ abort_command = AbortCommand(reason=reason or "User requested stop") - GraphEngineManager._send_command(task_id, abort_command) + self._send_command(task_id, abort_command) - @staticmethod - def send_pause_command(task_id: str, reason: str | None = None) -> None: + def send_pause_command(self, task_id: str, reason: str | None = None) -> None: """Send a pause command to a running workflow.""" pause_command = PauseCommand(reason=reason or "User requested pause") - GraphEngineManager._send_command(task_id, pause_command) + self._send_command(task_id, pause_command) - @staticmethod - def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None: + def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None: """Send a command to update variables in a running workflow.""" if not updates: return update_command = UpdateVariablesCommand(updates=updates) - GraphEngineManager._send_command(task_id, update_command) + self._send_command(task_id, update_command) - @staticmethod - def _send_command(task_id: str, command: GraphEngineCommand) -> None: + def _send_command(self, task_id: str, command: GraphEngineCommand) -> None: """Send a command to the workflow-specific Redis channel.""" if not task_id: return channel_key = f"workflow:{task_id}:commands" - channel = RedisChannel(redis_client, channel_key) + channel = RedisChannel(self._redis_client, channel_key) try: channel.send_command(command) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 3ca3598002..658e6a0738 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -111,6 +111,7 @@ class RedisClientWrapper: def zcard(self, name: str | bytes) -> Any: ... def getdel(self, name: str | bytes) -> Any: ... def pubsub(self) -> PubSub: ... + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... def __getattr__(self, item: str) -> Any: if self._client is None: diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 01874b3f9f..5ae1fba2e8 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -8,6 +8,7 @@ new GraphEngine command channel mechanism. from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.graph_engine.manager import GraphEngineManager +from extensions.ext_redis import redis_client from models.model import AppMode @@ -42,4 +43,4 @@ class AppTaskService: # New mechanism: Send stop command via GraphEngine for workflow-based apps # This ensures proper workflow status recording in the persistence layer if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW): - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(redis_client).send_stop_command(task_id) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 314393f059..0eb3854c84 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -596,7 +596,8 @@ class TestWorkflowTaskStopApiPost: assert result == {"result": "success"} mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1") - mock_graph_mgr.send_stop_command.assert_called_once_with("task-1") + mock_graph_mgr.assert_called_once() + mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1") def test_stop_workflow_task_wrong_app_mode(self, app): """Test NotWorkflowAppError when app mode is not workflow.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index f1a495d20a..0920940e51 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -32,25 +32,26 @@ class TestRedisStopIntegration: mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - with patch("core.workflow.graph_engine.manager.redis_client", mock_redis): - # Execute - GraphEngineManager.send_stop_command(task_id, reason="Test stop") + manager = GraphEngineManager(mock_redis) - # Verify - mock_redis.pipeline.assert_called_once() + # Execute + manager.send_stop_command(task_id, reason="Test stop") - # Check that rpush was called with correct arguments - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 + # Verify + mock_redis.pipeline.assert_called_once() - # Verify the channel key - assert calls[0][0][0] == expected_channel_key + # Check that rpush was called with correct arguments + calls = mock_pipeline.rpush.call_args_list + assert len(calls) == 1 - # Verify the command data - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT - assert command_data["reason"] == "Test stop" + # Verify the channel key + assert calls[0][0][0] == expected_channel_key + + # Verify the command data + command_json = calls[0][0][1] + command_data = json.loads(command_json) + assert command_data["command_type"] == CommandType.ABORT + assert command_data["reason"] == "Test stop" def test_graph_engine_manager_sends_pause_command(self): """Test that GraphEngineManager correctly sends pause command through Redis.""" @@ -62,18 +63,18 @@ class TestRedisStopIntegration: mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - with patch("core.workflow.graph_engine.manager.redis_client", mock_redis): - GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources") + manager = GraphEngineManager(mock_redis) + manager.send_pause_command(task_id, reason="Awaiting resources") - mock_redis.pipeline.assert_called_once() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == expected_channel_key + mock_redis.pipeline.assert_called_once() + calls = mock_pipeline.rpush.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == expected_channel_key - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.PAUSE.value - assert command_data["reason"] == "Awaiting resources" + command_json = calls[0][0][1] + command_data = json.loads(command_json) + assert command_data["command_type"] == CommandType.PAUSE.value + assert command_data["reason"] == "Awaiting resources" def test_graph_engine_manager_handles_redis_failure_gracefully(self): """Test that GraphEngineManager handles Redis failures without raising exceptions.""" @@ -82,13 +83,13 @@ class TestRedisStopIntegration: # Mock redis client to raise exception mock_redis = MagicMock() mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed") + manager = GraphEngineManager(mock_redis) - with patch("core.workflow.graph_engine.manager.redis_client", mock_redis): - # Should not raise exception - try: - GraphEngineManager.send_stop_command(task_id) - except Exception as e: - pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly") + # Should not raise exception + try: + manager.send_stop_command(task_id) + except Exception as e: + pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly") def test_app_queue_manager_no_user_check(self): """Test that AppQueueManager.set_stop_flag_no_user_check works without user validation.""" @@ -251,13 +252,10 @@ class TestRedisStopIntegration: mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - with ( - patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis), - patch("core.workflow.graph_engine.manager.redis_client", mock_redis), - ): + with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): # Execute both stop mechanisms AppQueueManager.set_stop_flag_no_user_check(task_id) - GraphEngineManager.send_stop_command(task_id) + GraphEngineManager(mock_redis).send_stop_command(task_id) # Verify legacy stop flag was set expected_stop_flag_key = f"generate_task_stopped:{task_id}" diff --git a/api/tests/unit_tests/services/test_app_task_service.py b/api/tests/unit_tests/services/test_app_task_service.py index e00486f77c..33ca4cb853 100644 --- a/api/tests/unit_tests/services/test_app_task_service.py +++ b/api/tests/unit_tests/services/test_app_task_service.py @@ -44,9 +44,10 @@ class TestAppTaskService: # Assert mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id) if should_call_graph_engine: - mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id) + mock_graph_engine_manager.assert_called_once() + mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id) else: - mock_graph_engine_manager.send_stop_command.assert_not_called() + mock_graph_engine_manager.assert_not_called() @pytest.mark.parametrize( "invoke_from", @@ -76,7 +77,8 @@ class TestAppTaskService: # Assert mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id) - mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id) + mock_graph_engine_manager.assert_called_once() + mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id) @patch("services.app_task_service.GraphEngineManager") @patch("services.app_task_service.AppQueueManager") @@ -96,7 +98,7 @@ class TestAppTaskService: app_mode = AppMode.ADVANCED_CHAT # Simulate GraphEngine failure - mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error") + mock_graph_engine_manager.return_value.send_stop_command.side_effect = Exception("GraphEngine error") # Act & Assert - should raise the exception since it's not caught with pytest.raises(Exception, match="GraphEngine error"):