mirror of https://github.com/langgenius/dify.git
fix(api): excessive high CPU usage caused by RedisClientWrapper (#32212)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 3119c99979
commit 704ee40caa
@@ -34,7 +34,7 @@ def stream_topic_events(
     on_subscribe()
     while True:
         try:
-            msg = sub.receive(timeout=0.1)
+            msg = sub.receive(timeout=1)
         except SubscriptionClosedError:
             return
         if msg is None:
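Most hunks in this commit make the same change: poll timeouts go from 0.1 s to 1 s, so an idle consumer wakes roughly once per second instead of ten times per second. A minimal sketch of the pattern, using a plain `queue.Queue` and `threading.Event` rather than Dify's actual subscription types (the `consume` helper is illustrative, not project code):

```python
import queue
import threading


def consume(q: queue.Queue, stop: threading.Event) -> None:
    # Block for up to 1 second per poll instead of 0.1 s; an idle consumer
    # now wakes ~1x per second rather than ~10x, which is where most of the
    # CPU savings in this commit come from.
    while not stop.is_set():
        try:
            item = q.get(timeout=1)
        except queue.Empty:
            continue  # nothing arrived within the timeout window; poll again
        print("received:", item)
```

The trade-off is slightly slower shutdown: the loop notices a stop signal (or a closed subscription) up to one second later than before.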
@@ -119,7 +119,7 @@ class RedisClientWrapper:


 redis_client: RedisClientWrapper = RedisClientWrapper()
-pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
+_pubsub_redis_client: redis.Redis | RedisCluster | None = None


 def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:

@@ -232,7 +232,7 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
     return client


-def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
+def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
     if use_clusters:
         return RedisCluster.from_url(pubsub_url)
     return redis.Redis.from_url(pubsub_url)

@@ -256,23 +256,19 @@ def init_app(app: DifyApp):
     redis_client.initialize(client)
     app.extensions["redis"] = redis_client

-    pubsub_client = client
+    global _pubsub_redis_client
+    _pubsub_redis_client = client
     if dify_config.normalized_pubsub_redis_url:
-        pubsub_client = _create_pubsub_client(
+        _pubsub_redis_client = _create_pubsub_client(
             dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
         )
-    pubsub_redis_client.initialize(pubsub_client)
-
-
-def get_pubsub_redis_client() -> RedisClientWrapper:
-    return pubsub_redis_client


 def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
-    redis_conn = get_pubsub_redis_client()
+    assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
     if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
-        return ShardedRedisBroadcastChannel(redis_conn)  # pyright: ignore[reportArgumentType]
-    return RedisBroadcastChannel(redis_conn)  # pyright: ignore[reportArgumentType]
+        return ShardedRedisBroadcastChannel(_pubsub_redis_client)
+    return RedisBroadcastChannel(_pubsub_redis_client)


 P = ParamSpec("P")
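The pub/sub connection is no longer wrapped in `RedisClientWrapper` (the wrapper the commit title blames for the CPU usage); a module-level `_pubsub_redis_client` now holds the raw `redis.Redis` or `RedisCluster` instance, and `get_pubsub_broadcast_channel()` asserts it was initialized before use. A rough sketch of the resulting access pattern, with simplified names (`init_pubsub`, `get_pubsub_client`) standing in for the real `ext_redis` / `dify_config` code:

```python
import redis
from redis.cluster import RedisCluster

# Module-level holder for the raw pub/sub client; no wrapper object sits
# between callers and redis-py any more.
_pubsub_redis_client: redis.Redis | RedisCluster | None = None


def init_pubsub(url: str, use_clusters: bool) -> None:
    global _pubsub_redis_client
    _pubsub_redis_client = RedisCluster.from_url(url) if use_clusters else redis.Redis.from_url(url)


def get_pubsub_client() -> redis.Redis | RedisCluster:
    # Fail loudly if init_pubsub() was never called, mirroring the assert
    # added to get_pubsub_broadcast_channel().
    assert _pubsub_redis_client is not None, "pub/sub client must be initialized first"
    return _pubsub_redis_client
```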
@@ -152,7 +152,7 @@ class RedisSubscriptionBase(Subscription):
         """Iterator for consuming messages from the subscription."""
         while not self._closed.is_set():
             try:
-                item = self._queue.get(timeout=0.1)
+                item = self._queue.get(timeout=1)
             except queue.Empty:
                 continue

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
-from redis import Redis
+from redis import Redis, RedisCluster

 from ._subscription import RedisSubscriptionBase


@@ -18,7 +18,7 @@ class BroadcastChannel:

     def __init__(
         self,
-        redis_client: Redis,
+        redis_client: Redis | RedisCluster,
     ):
         self._client = redis_client


@@ -27,7 +27,7 @@ class BroadcastChannel:


 class Topic:
-    def __init__(self, redis_client: Redis, topic: str):
+    def __init__(self, redis_client: Redis | RedisCluster, topic: str):
         self._client = redis_client
         self._topic = topic

@@ -70,8 +70,9 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
         # Since we have already filtered at the caller's site, we can safely set
         # `ignore_subscribe_messages=False`.
         if isinstance(self._client, RedisCluster):
-            # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
-            # would use busy-looping to wait for incoming message, consuming excessive CPU quota.
+            # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
+            # specifying the `target_node` argument would use busy-looping to wait
+            # for incoming message, consuming excessive CPU quota.
             #
             # Here we specify the `target_node` to mitigate this problem.
             node = self._client.get_node_from_key(self._topic)

@@ -80,8 +81,10 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
                 timeout=1,
                 target_node=node,
             )
-        else:
+        elif isinstance(self._client, Redis):
             return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1)  # type: ignore[attr-defined]
+        else:
+            raise AssertionError("client should be either Redis or RedisCluster.")

     def _get_message_type(self) -> str:
         return "smessage"
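The NOTE in the hunk above describes the cluster-specific part of the fix: calling `get_sharded_message()` on a cluster pub/sub object without a `target_node` can busy-loop while waiting for the next message. The subscription now resolves the node that owns the channel's hash slot and passes it explicitly. A minimal sketch of that call pattern, assuming a redis-py `RedisCluster` and a sharded pub/sub object obtained from it; `receive_sharded` and `topic` are placeholders, not project code:

```python
from redis.cluster import RedisCluster


def receive_sharded(client: RedisCluster, pubsub, topic: str):
    # Resolve the cluster node that owns the channel's hash slot and hand it
    # to get_sharded_message(); without target_node the call can busy-loop
    # and burn CPU while waiting for a message.
    node = client.get_node_from_key(topic)
    return pubsub.get_sharded_message(
        ignore_subscribe_messages=False,
        timeout=1,
        target_node=node,
    )
```

For a standalone (non-cluster) client the diff keeps the plain `get_sharded_message(..., timeout=1)` call; the `target_node` workaround is only needed on the cluster path.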
@@ -22,7 +22,7 @@ from libs.exception import BaseHTTPException
 from models.human_input import RecipientType
 from models.model import App, AppMode
 from repositories.factory import DifyAPIRepositoryFactory
-from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution
+from tasks.app_generate.workflow_execute_task import resume_app_execution


 class Form:

@@ -230,7 +230,6 @@ class HumanInputService:
         try:
             resume_app_execution.apply_async(
                 kwargs={"payload": payload},
-                queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE,
             )
         except Exception:  # pragma: no cover
             logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
@@ -129,15 +129,15 @@ def build_workflow_event_stream(
                 return

             try:
-                event = buffer_state.queue.get(timeout=0.1)
+                event = buffer_state.queue.get(timeout=1)
             except queue.Empty:
                 current_time = time.time()
                 if current_time - last_msg_time > idle_timeout:
                     logger.debug(
-                        "No workflow events received for %s seconds, keeping stream open",
+                        "Idle timeout of %s seconds reached, closing workflow event stream.",
                         idle_timeout,
                     )
-                    last_msg_time = current_time
+                    return
                 if current_time - last_ping_time >= ping_interval:
                     yield StreamEvent.PING.value
                     last_ping_time = current_time

@@ -405,7 +405,7 @@ def _start_buffering(subscription) -> BufferState:
     dropped_count = 0
     try:
         while not buffer_state.stop_event.is_set():
-            msg = subscription.receive(timeout=0.1)
+            msg = subscription.receive(timeout=1)
            if msg is None:
                continue
            event = _parse_event_message(msg)
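Besides the slower poll, the idle branch of the event stream changes behaviour: once no event has arrived for `idle_timeout` seconds the generator now returns (closing the stream) instead of resetting the idle timer and keeping the connection open, and the log message is updated to match. A compact sketch of the resulting loop, using a plain `queue.Queue` and string events in place of Dify's buffer and `StreamEvent` types (all names here are illustrative):

```python
import queue
import time


def event_stream(q: queue.Queue, idle_timeout: float = 60.0, ping_interval: float = 10.0):
    last_msg_time = last_ping_time = time.time()
    while True:
        try:
            event = q.get(timeout=1)  # 1 s poll instead of 0.1 s
        except queue.Empty:
            now = time.time()
            if now - last_msg_time > idle_timeout:
                return  # idle timeout reached: close the stream instead of keeping it open
            if now - last_ping_time >= ping_interval:
                yield "ping"  # keepalive so clients and proxies do not drop the connection
                last_ping_time = now
            continue
        last_msg_time = time.time()
        yield event
```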
@@ -51,7 +51,7 @@ def _patch_redis_clients_on_loaded_modules():
             continue
         if hasattr(module, "redis_client"):
             module.redis_client = redis_mock
-        if hasattr(module, "pubsub_redis_client"):
+        if hasattr(module, "_pubsub_redis_client"):
             module.pubsub_redis_client = redis_mock


@@ -72,7 +72,7 @@ def _patch_redis_clients():

     with (
         patch.object(ext_redis, "redis_client", redis_mock),
-        patch.object(ext_redis, "pubsub_redis_client", redis_mock),
+        patch.object(ext_redis, "_pubsub_redis_client", redis_mock),
     ):
         _patch_redis_clients_on_loaded_modules()
         yield
@@ -198,6 +198,15 @@ class SubscriptionTestCase:
     description: str = ""


+class FakeRedisClient:
+    """Minimal fake Redis client for unit tests."""
+
+    def __init__(self) -> None:
+        self.publish = MagicMock()
+        self.spublish = MagicMock()
+        self.pubsub = MagicMock(return_value=MagicMock())
+
+
 class TestRedisSubscription:
     """Test cases for the _RedisSubscription class."""


@@ -619,10 +628,13 @@ class TestRedisSubscription:
 class TestRedisShardedSubscription:
     """Test cases for the _RedisShardedSubscription class."""

+    @pytest.fixture(autouse=True)
+    def patch_sharded_redis_type(self, monkeypatch):
+        monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient)
+
     @pytest.fixture
-    def mock_redis_client(self) -> MagicMock:
-        client = MagicMock()
-        return client
+    def mock_redis_client(self) -> FakeRedisClient:
+        return FakeRedisClient()

     @pytest.fixture
     def mock_pubsub(self) -> MagicMock:

@@ -636,7 +648,7 @@ class TestRedisShardedSubscription:

     @pytest.fixture
     def sharded_subscription(
-        self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
+        self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient
     ) -> Generator[_RedisShardedSubscription, None, None]:
         """Create a _RedisShardedSubscription instance for testing."""
         subscription = _RedisShardedSubscription(

@@ -657,7 +669,7 @@ class TestRedisShardedSubscription:

     # ==================== Lifecycle Tests ====================

-    def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
+    def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
         """Test that sharded subscription is properly initialized."""
         subscription = _RedisShardedSubscription(
             client=mock_redis_client,

@@ -970,7 +982,7 @@ class TestRedisShardedSubscription:
         ],
     )
     def test_sharded_subscription_scenarios(
-        self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
+        self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient
     ):
         """Test various sharded subscription scenarios using table-driven approach."""
         subscription = _RedisShardedSubscription(

@@ -1058,7 +1070,7 @@ class TestRedisShardedSubscription:
         # Close should still work
         sharded_subscription.close()  # Should not raise

-    def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
+    def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
         """Test various sharded channel name formats."""
         channel_names = [
             "simple",

@@ -1120,10 +1132,13 @@ class TestRedisSubscriptionCommon:
         """Parameterized fixture providing subscription type and class."""
         return request.param

+    @pytest.fixture(autouse=True)
+    def patch_sharded_redis_type(self, monkeypatch):
+        monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient)
+
     @pytest.fixture
-    def mock_redis_client(self) -> MagicMock:
-        client = MagicMock()
-        return client
+    def mock_redis_client(self) -> FakeRedisClient:
+        return FakeRedisClient()

     @pytest.fixture
     def mock_pubsub(self) -> MagicMock:

@@ -1140,7 +1155,7 @@ class TestRedisSubscriptionCommon:
         return pubsub

     @pytest.fixture
-    def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
+    def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
         """Create a subscription instance based on parameterized type."""
         subscription_type, subscription_class = subscription_params
         topic_name = f"test-{subscription_type}-topic"
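The test changes follow from the new `isinstance(self._client, Redis)` / `RedisCluster` dispatch in the sharded subscription: a bare `MagicMock` no longer satisfies the type check, so the tests add a concrete `FakeRedisClient` and monkeypatch the `Redis` name imported by `sharded_channel` to that fake. A self-contained illustration of the idea (not Dify code; the module stand-in and names are made up):

```python
import types


class _RealRedis:
    """Placeholder for the real redis.Redis type."""


# Stand-in for the sharded_channel module; its `Redis` symbol is what the fixture patches.
channel_module = types.SimpleNamespace(Redis=_RealRedis)


class FakeRedisClient:
    """A real class (unlike MagicMock) so isinstance() checks can match it."""


def dispatch(client) -> str:
    if isinstance(client, channel_module.Redis):
        return "redis-branch"
    raise AssertionError("client should be either Redis or RedisCluster.")


# Test-style usage: point the module's Redis name at the fake, as the pytest
# fixture does via monkeypatch.setattr(...); the fake then takes the Redis branch.
channel_module.Redis = FakeRedisClient
assert dispatch(FakeRedisClient()) == "redis-branch"
```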
@@ -17,7 +17,6 @@ from core.workflow.nodes.human_input.entities import (
 from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
 from models.human_input import RecipientType
 from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
-from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE


 @pytest.fixture

@@ -88,7 +87,6 @@ def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factor

     resume_task.apply_async.assert_called_once()
     call_kwargs = resume_task.apply_async.call_args.kwargs
-    assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
     assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"


@@ -130,7 +128,6 @@ def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_f

     resume_task.apply_async.assert_called_once()
     call_kwargs = resume_task.apply_async.call_args.kwargs
-    assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
     assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
