fix(api): excessive high CPU usage caused by RedisClientWrapper (#32212)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
QuantumGhost 2026-02-11 09:49:29 +08:00 committed by GitHub
parent 3119c99979
commit 704ee40caa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 52 additions and 42 deletions

View File

@ -34,7 +34,7 @@ def stream_topic_events(
on_subscribe()
while True:
try:
msg = sub.receive(timeout=0.1)
msg = sub.receive(timeout=1)
except SubscriptionClosedError:
return
if msg is None:

View File

@ -119,7 +119,7 @@ class RedisClientWrapper:
redis_client: RedisClientWrapper = RedisClientWrapper()
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
@ -232,7 +232,7 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
if use_clusters:
return RedisCluster.from_url(pubsub_url)
return redis.Redis.from_url(pubsub_url)
@ -256,23 +256,19 @@ def init_app(app: DifyApp):
redis_client.initialize(client)
app.extensions["redis"] = redis_client
pubsub_client = client
global _pubsub_redis_client
_pubsub_redis_client = client
if dify_config.normalized_pubsub_redis_url:
pubsub_client = _create_pubsub_client(
_pubsub_redis_client = _create_pubsub_client(
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
)
pubsub_redis_client.initialize(pubsub_client)
def get_pubsub_redis_client() -> RedisClientWrapper:
return pubsub_redis_client
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
redis_conn = get_pubsub_redis_client()
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
return RedisBroadcastChannel(_pubsub_redis_client)
P = ParamSpec("P")

View File

@ -152,7 +152,7 @@ class RedisSubscriptionBase(Subscription):
"""Iterator for consuming messages from the subscription."""
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
item = self._queue.get(timeout=1)
except queue.Empty:
continue

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from redis import Redis, RedisCluster
from ._subscription import RedisSubscriptionBase
@ -18,7 +18,7 @@ class BroadcastChannel:
def __init__(
self,
redis_client: Redis,
redis_client: Redis | RedisCluster,
):
self._client = redis_client
@ -27,7 +27,7 @@ class BroadcastChannel:
class Topic:
def __init__(self, redis_client: Redis, topic: str):
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic

View File

@ -70,8 +70,9 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
if isinstance(self._client, RedisCluster):
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
# specifying the `target_node` argument would use busy-looping to wait
# for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
@ -80,8 +81,10 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
timeout=1,
target_node=node,
)
else:
elif isinstance(self._client, Redis):
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
else:
raise AssertionError("client should be either Redis or RedisCluster.")
def _get_message_type(self) -> str:
return "smessage"

View File

@ -22,7 +22,7 @@ from libs.exception import BaseHTTPException
from models.human_input import RecipientType
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution
from tasks.app_generate.workflow_execute_task import resume_app_execution
class Form:
@ -230,7 +230,6 @@ class HumanInputService:
try:
resume_app_execution.apply_async(
kwargs={"payload": payload},
queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE,
)
except Exception: # pragma: no cover
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)

View File

@ -129,15 +129,15 @@ def build_workflow_event_stream(
return
try:
event = buffer_state.queue.get(timeout=0.1)
event = buffer_state.queue.get(timeout=1)
except queue.Empty:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
logger.debug(
"No workflow events received for %s seconds, keeping stream open",
"Idle timeout of %s seconds reached, closing workflow event stream.",
idle_timeout,
)
last_msg_time = current_time
return
if current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
@ -405,7 +405,7 @@ def _start_buffering(subscription) -> BufferState:
dropped_count = 0
try:
while not buffer_state.stop_event.is_set():
msg = subscription.receive(timeout=0.1)
msg = subscription.receive(timeout=1)
if msg is None:
continue
event = _parse_event_message(msg)

View File

@ -51,7 +51,7 @@ def _patch_redis_clients_on_loaded_modules():
continue
if hasattr(module, "redis_client"):
module.redis_client = redis_mock
if hasattr(module, "pubsub_redis_client"):
if hasattr(module, "_pubsub_redis_client"):
module.pubsub_redis_client = redis_mock
@ -72,7 +72,7 @@ def _patch_redis_clients():
with (
patch.object(ext_redis, "redis_client", redis_mock),
patch.object(ext_redis, "pubsub_redis_client", redis_mock),
patch.object(ext_redis, "_pubsub_redis_client", redis_mock),
):
_patch_redis_clients_on_loaded_modules()
yield

View File

@ -198,6 +198,15 @@ class SubscriptionTestCase:
description: str = ""
class FakeRedisClient:
"""Minimal fake Redis client for unit tests."""
def __init__(self) -> None:
self.publish = MagicMock()
self.spublish = MagicMock()
self.pubsub = MagicMock(return_value=MagicMock())
class TestRedisSubscription:
"""Test cases for the _RedisSubscription class."""
@ -619,10 +628,13 @@ class TestRedisSubscription:
class TestRedisShardedSubscription:
"""Test cases for the _RedisShardedSubscription class."""
@pytest.fixture(autouse=True)
def patch_sharded_redis_type(self, monkeypatch):
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient)
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
def mock_redis_client(self) -> FakeRedisClient:
return FakeRedisClient()
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
@ -636,7 +648,7 @@ class TestRedisShardedSubscription:
@pytest.fixture
def sharded_subscription(
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient
) -> Generator[_RedisShardedSubscription, None, None]:
"""Create a _RedisShardedSubscription instance for testing."""
subscription = _RedisShardedSubscription(
@ -657,7 +669,7 @@ class TestRedisShardedSubscription:
# ==================== Lifecycle Tests ====================
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
"""Test that sharded subscription is properly initialized."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
@ -970,7 +982,7 @@ class TestRedisShardedSubscription:
],
)
def test_sharded_subscription_scenarios(
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient
):
"""Test various sharded subscription scenarios using table-driven approach."""
subscription = _RedisShardedSubscription(
@ -1058,7 +1070,7 @@ class TestRedisShardedSubscription:
# Close should still work
sharded_subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
"""Test various sharded channel name formats."""
channel_names = [
"simple",
@ -1120,10 +1132,13 @@ class TestRedisSubscriptionCommon:
"""Parameterized fixture providing subscription type and class."""
return request.param
@pytest.fixture(autouse=True)
def patch_sharded_redis_type(self, monkeypatch):
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient)
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
def mock_redis_client(self) -> FakeRedisClient:
return FakeRedisClient()
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
@ -1140,7 +1155,7 @@ class TestRedisSubscriptionCommon:
return pubsub
@pytest.fixture
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
"""Create a subscription instance based on parameterized type."""
subscription_type, subscription_class = subscription_params
topic_name = f"test-{subscription_type}-topic"

View File

@ -17,7 +17,6 @@ from core.workflow.nodes.human_input.entities import (
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
from models.human_input import RecipientType
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE
@pytest.fixture
@ -88,7 +87,6 @@ def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factor
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
@ -130,7 +128,6 @@ def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_f
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"