diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py index 57d4b537a4..af3441aca3 100644 --- a/api/core/app/apps/streaming_utils.py +++ b/api/core/app/apps/streaming_utils.py @@ -34,7 +34,7 @@ def stream_topic_events( on_subscribe() while True: try: - msg = sub.receive(timeout=0.1) + msg = sub.receive(timeout=1) except SubscriptionClosedError: return if msg is None: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 0797a3cb98..3ca3598002 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -119,7 +119,7 @@ class RedisClientWrapper: redis_client: RedisClientWrapper = RedisClientWrapper() -pubsub_redis_client: RedisClientWrapper = RedisClientWrapper() +_pubsub_redis_client: redis.Redis | RedisCluster | None = None def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: @@ -232,7 +232,7 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis return client -def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]: +def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: if use_clusters: return RedisCluster.from_url(pubsub_url) return redis.Redis.from_url(pubsub_url) @@ -256,23 +256,19 @@ def init_app(app: DifyApp): redis_client.initialize(client) app.extensions["redis"] = redis_client - pubsub_client = client + global _pubsub_redis_client + _pubsub_redis_client = client if dify_config.normalized_pubsub_redis_url: - pubsub_client = _create_pubsub_client( + _pubsub_redis_client = _create_pubsub_client( dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS ) - pubsub_redis_client.initialize(pubsub_client) - - -def get_pubsub_redis_client() -> RedisClientWrapper: - return pubsub_redis_client def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: - redis_conn = get_pubsub_redis_client() + assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": - return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] - return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] + return ShardedRedisBroadcastChannel(_pubsub_redis_client) + return RedisBroadcastChannel(_pubsub_redis_client) P = ParamSpec("P") diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index df81775660..40027bc424 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -152,7 +152,7 @@ class RedisSubscriptionBase(Subscription): """Iterator for consuming messages from the subscription.""" while not self._closed.is_set(): try: - item = self._queue.get(timeout=0.1) + item = self._queue.get(timeout=1) except queue.Empty: continue diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 35a227769c..bd6d58c53f 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -1,7 +1,7 @@ from __future__ import annotations from libs.broadcast_channel.channel import Producer, Subscriber, Subscription -from redis import Redis +from redis import Redis, RedisCluster from ._subscription import RedisSubscriptionBase @@ -18,7 +18,7 @@ class BroadcastChannel: def __init__( self, - redis_client: Redis, + redis_client: Redis | RedisCluster, ): self._client = redis_client @@ -27,7 +27,7 @@ class BroadcastChannel: class Topic: - def __init__(self, redis_client: Redis, topic: str): + def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 290c077d11..20c43b8bbb 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -70,8 +70,9 @@ class _RedisShardedSubscription(RedisSubscriptionBase): # Since we have already filtered at the caller's site, we can safely set # `ignore_subscribe_messages=False`. if isinstance(self._client, RedisCluster): - # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` - # would use busy-looping to wait for incoming message, consuming excessive CPU quota. + # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without + # specifying the `target_node` argument would use busy-looping to wait + # for incoming message, consuming excessive CPU quota. # # Here we specify the `target_node` to mitigate this problem. node = self._client.get_node_from_key(self._topic) @@ -80,8 +81,10 @@ class _RedisShardedSubscription(RedisSubscriptionBase): timeout=1, target_node=node, ) - else: + elif isinstance(self._client, Redis): return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined] + else: + raise AssertionError("client should be either Redis or RedisCluster.") def _get_message_type(self) -> str: return "smessage" diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 76b6e6e0e6..87816643f6 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -22,7 +22,7 @@ from libs.exception import BaseHTTPException from models.human_input import RecipientType from models.model import App, AppMode from repositories.factory import DifyAPIRepositoryFactory -from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution +from tasks.app_generate.workflow_execute_task import resume_app_execution class Form: @@ -230,7 +230,6 @@ class HumanInputService: try: resume_app_execution.apply_async( kwargs={"payload": payload}, - queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, ) except Exception: # pragma: no cover logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id) diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 74211e1340..09037a92ce 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -129,15 +129,15 @@ def build_workflow_event_stream( return try: - event = buffer_state.queue.get(timeout=0.1) + event = buffer_state.queue.get(timeout=1) except queue.Empty: current_time = time.time() if current_time - last_msg_time > idle_timeout: logger.debug( - "No workflow events received for %s seconds, keeping stream open", + "Idle timeout of %s seconds reached, closing workflow event stream.", idle_timeout, ) - last_msg_time = current_time + return if current_time - last_ping_time >= ping_interval: yield StreamEvent.PING.value last_ping_time = current_time @@ -405,7 +405,7 @@ def _start_buffering(subscription) -> BufferState: dropped_count = 0 try: while not buffer_state.stop_event.is_set(): - msg = subscription.receive(timeout=0.1) + msg = subscription.receive(timeout=1) if msg is None: continue event = _parse_event_message(msg) diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index da957d3a81..e443f48f3b 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -51,7 +51,7 @@ def _patch_redis_clients_on_loaded_modules(): continue if hasattr(module, "redis_client"): module.redis_client = redis_mock - if hasattr(module, "pubsub_redis_client"): + if hasattr(module, "_pubsub_redis_client"): module.pubsub_redis_client = redis_mock @@ -72,7 +72,7 @@ def _patch_redis_clients(): with ( patch.object(ext_redis, "redis_client", redis_mock), - patch.object(ext_redis, "pubsub_redis_client", redis_mock), + patch.object(ext_redis, "_pubsub_redis_client", redis_mock), ): _patch_redis_clients_on_loaded_modules() yield diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index f206c411fd..f84df42bfd 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -198,6 +198,15 @@ class SubscriptionTestCase: description: str = "" +class FakeRedisClient: + """Minimal fake Redis client for unit tests.""" + + def __init__(self) -> None: + self.publish = MagicMock() + self.spublish = MagicMock() + self.pubsub = MagicMock(return_value=MagicMock()) + + class TestRedisSubscription: """Test cases for the _RedisSubscription class.""" @@ -619,10 +628,13 @@ class TestRedisSubscription: class TestRedisShardedSubscription: """Test cases for the _RedisShardedSubscription class.""" + @pytest.fixture(autouse=True) + def patch_sharded_redis_type(self, monkeypatch): + monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient) + @pytest.fixture - def mock_redis_client(self) -> MagicMock: - client = MagicMock() - return client + def mock_redis_client(self) -> FakeRedisClient: + return FakeRedisClient() @pytest.fixture def mock_pubsub(self) -> MagicMock: @@ -636,7 +648,7 @@ class TestRedisShardedSubscription: @pytest.fixture def sharded_subscription( - self, mock_pubsub: MagicMock, mock_redis_client: MagicMock + self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient ) -> Generator[_RedisShardedSubscription, None, None]: """Create a _RedisShardedSubscription instance for testing.""" subscription = _RedisShardedSubscription( @@ -657,7 +669,7 @@ class TestRedisShardedSubscription: # ==================== Lifecycle Tests ==================== - def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): + def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient): """Test that sharded subscription is properly initialized.""" subscription = _RedisShardedSubscription( client=mock_redis_client, @@ -970,7 +982,7 @@ class TestRedisShardedSubscription: ], ) def test_sharded_subscription_scenarios( - self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock + self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient ): """Test various sharded subscription scenarios using table-driven approach.""" subscription = _RedisShardedSubscription( @@ -1058,7 +1070,7 @@ class TestRedisShardedSubscription: # Close should still work sharded_subscription.close() # Should not raise - def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): + def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient): """Test various sharded channel name formats.""" channel_names = [ "simple", @@ -1120,10 +1132,13 @@ class TestRedisSubscriptionCommon: """Parameterized fixture providing subscription type and class.""" return request.param + @pytest.fixture(autouse=True) + def patch_sharded_redis_type(self, monkeypatch): + monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient) + @pytest.fixture - def mock_redis_client(self) -> MagicMock: - client = MagicMock() - return client + def mock_redis_client(self) -> FakeRedisClient: + return FakeRedisClient() @pytest.fixture def mock_pubsub(self) -> MagicMock: @@ -1140,7 +1155,7 @@ class TestRedisSubscriptionCommon: return pubsub @pytest.fixture - def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock): + def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient): """Create a subscription instance based on parameterized type.""" subscription_type, subscription_class = subscription_params topic_name = f"test-{subscription_type}-topic" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index d2cf74daf3..5800d029ca 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -17,7 +17,6 @@ from core.workflow.nodes.human_input.entities import ( from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus from models.human_input import RecipientType from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError -from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE @pytest.fixture @@ -88,7 +87,6 @@ def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factor resume_task.apply_async.assert_called_once() call_kwargs = resume_task.apply_async.call_args.kwargs - assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" @@ -130,7 +128,6 @@ def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_f resume_task.apply_async.assert_called_once() call_kwargs = resume_task.apply_async.call_args.kwargs - assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"