mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 04:51:11 +08:00
refactor: improve stream close 2 (#37106)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ef54229d6f
commit
7852c273e4
@ -768,7 +768,6 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -71,24 +70,6 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
|
||||
description=(
|
||||
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
|
||||
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
|
||||
"an SSE client and the response stream actually closing.\n\n"
|
||||
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
|
||||
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
|
||||
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
|
||||
"return promptly while the daemon listener thread cleans itself up on the next poll "
|
||||
"boundary - safe because the listener holds no critical state and exits within one poll "
|
||||
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
|
||||
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
|
||||
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
|
||||
),
|
||||
default=2000,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@ -25,7 +25,7 @@ from extensions.redis_names import (
|
||||
serialize_redis_name_args,
|
||||
)
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
@ -457,16 +457,14 @@ def init_app(app: DifyApp):
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||
return StreamsBroadcastChannel(
|
||||
_pubsub_redis_client,
|
||||
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||
join_timeout_ms=join_timeout_ms,
|
||||
)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||
|
||||
|
||||
def redis_fallback[T](default_return: T | None = None): # type: ignore
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .channel import BroadcastChannel
|
||||
from .pubsub_channel import BroadcastChannel
|
||||
from .sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, Self, override
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
from redis.client import PubSub
|
||||
|
||||
@ -26,8 +27,6 @@ class RedisSubscriptionBase(Subscription):
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
@ -39,11 +38,6 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency. The listener is a daemon
|
||||
# and exits on its own within one poll window (~1s), so a low value
|
||||
# here just means close() returns sooner without breaking anything.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
"""Start the subscription if not already started."""
|
||||
@ -90,6 +84,11 @@ class RedisSubscriptionBase(Subscription):
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
# If close() sent a control event to unblock us, exit immediately
|
||||
# without processing any message — the subscription is shutting down.
|
||||
if self._closed.is_set():
|
||||
break
|
||||
|
||||
if raw_message.get("type") != self._get_message_type():
|
||||
continue
|
||||
|
||||
@ -119,6 +118,8 @@ class RedisSubscriptionBase(Subscription):
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
if payload_bytes == SIG_CLOSE:
|
||||
break
|
||||
|
||||
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
|
||||
try:
|
||||
@ -212,13 +213,16 @@ class RedisSubscriptionBase(Subscription):
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# Send a control event on the same Redis channel to unblock the
|
||||
self._publish_close_event()
|
||||
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
|
||||
# message retrieval method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
listener.join(timeout=2)
|
||||
self._listener_thread = None
|
||||
|
||||
# Abstract methods to be implemented by subclasses
|
||||
@ -226,6 +230,15 @@ class RedisSubscriptionBase(Subscription):
|
||||
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _publish_close_event(self) -> None:
|
||||
"""Publish a control event on the Redis channel to unblock the listener.
|
||||
|
||||
This is called by close() after setting _closed. The subclass should
|
||||
publish an empty message on the same topic so that a blocking
|
||||
get_message() call in the listener thread returns promptly.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _subscribe(self) -> None:
|
||||
"""Subscribe to the Redis topic using the appropriate command."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
@ -22,16 +26,11 @@ class BroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
# See `RedisSubscriptionBase._join_timeout_ms`: how long close()
|
||||
# waits for the listener thread before returning.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> Topic:
|
||||
return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
return Topic(self._client, topic)
|
||||
|
||||
|
||||
class Topic:
|
||||
@ -39,13 +38,10 @@ class Topic:
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@ -61,7 +57,6 @@ class Topic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -72,6 +67,13 @@ class _RedisSubscription(RedisSubscriptionBase):
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "regular"
|
||||
|
||||
@override
|
||||
def _publish_close_event(self) -> None:
|
||||
try:
|
||||
self._client.publish(self._topic, SIG_CLOSE)
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShardedRedisBroadcastChannel:
|
||||
"""
|
||||
@ -20,14 +24,11 @@ class ShardedRedisBroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> ShardedTopic:
|
||||
return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
return ShardedTopic(self._client, topic)
|
||||
|
||||
|
||||
class ShardedTopic:
|
||||
@ -35,13 +36,10 @@ class ShardedTopic:
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@ -57,7 +55,6 @@ class ShardedTopic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -68,6 +65,13 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "sharded"
|
||||
|
||||
@override
|
||||
def _publish_close_event(self) -> None:
|
||||
try:
|
||||
self._client.spublish(self._topic, SIG_CLOSE) # type: ignore[attr-defined,union-attr]
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Self, override
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -29,20 +30,15 @@ class StreamsBroadcastChannel:
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._retention_seconds = max(int(retention_seconds or 0), 0)
|
||||
# Max time close() will wait for the listener thread to finish.
|
||||
# See `_StreamsSubscription._join_timeout_ms` for the rationale.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> StreamsTopic:
|
||||
return StreamsTopic(
|
||||
self._client,
|
||||
topic,
|
||||
retention_seconds=self._retention_seconds,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -53,13 +49,11 @@ class StreamsTopic:
|
||||
topic: str,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
self.max_length = 5000
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
@ -77,23 +71,15 @@ class StreamsTopic:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _StreamsSubscription(self._client, self._key, join_timeout_ms=self._join_timeout_ms)
|
||||
return _StreamsSubscription(self._client, self._key)
|
||||
|
||||
|
||||
class _StreamsSubscription(Subscription):
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000):
|
||||
def __init__(self, client: Redis | RedisCluster, key: str):
|
||||
self._client = client
|
||||
self._key = key
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency: the listener blocks on
|
||||
# XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for
|
||||
# the thread to notice _closed. Setting this lower lets close()
|
||||
# return promptly while the daemon listener exits on its own within
|
||||
# one BLOCK window - safe because the listener holds no critical
|
||||
# state. ``0`` means close() does not wait at all.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
|
||||
@ -106,7 +92,6 @@ class _StreamsSubscription(Subscription):
|
||||
# reading and writing the _listener / `_closed` attribute.
|
||||
self._lock = threading.Lock()
|
||||
self._closed: bool = False
|
||||
# self._closed = threading.Event()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
||||
def _listen(self) -> None:
|
||||
@ -144,6 +129,8 @@ class _StreamsSubscription(Subscription):
|
||||
case bytes() | bytearray():
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
if data_bytes == SIG_CLOSE:
|
||||
break
|
||||
self._queue.put_nowait(data_bytes)
|
||||
last_id = entry_id
|
||||
finally:
|
||||
@ -203,6 +190,13 @@ class _StreamsSubscription(Subscription):
|
||||
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
|
||||
return bytes(item)
|
||||
|
||||
def _publish_close_event(self) -> None:
|
||||
"""Publish an empty message to the stream to unblock the listener's xread."""
|
||||
try:
|
||||
self._client.xadd(self._key, {b"data": SIG_CLOSE})
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
@ -212,16 +206,17 @@ class _StreamsSubscription(Subscription):
|
||||
listener = self._listener
|
||||
if listener is not None:
|
||||
self._listener = None
|
||||
# We close the listener outside of the with block to avoid holding the
|
||||
# lock for a long time.
|
||||
|
||||
if listener is not None:
|
||||
self._publish_close_event()
|
||||
|
||||
if listener is not None and listener.is_alive():
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
listener.join(timeout=2)
|
||||
if listener.is_alive():
|
||||
logger.debug(
|
||||
"Streams subscription listener for key %s did not stop within %dms; "
|
||||
"Streams subscription listener for key %s did not stop after join; "
|
||||
"daemon thread will exit on its own within one poll window.",
|
||||
self._key,
|
||||
self._join_timeout_ms,
|
||||
)
|
||||
|
||||
# Context manager helpers
|
||||
|
||||
1
api/libs/broadcast_channel/signals.py
Normal file
1
api/libs/broadcast_channel/signals.py
Normal file
@ -0,0 +1 @@
|
||||
SIG_CLOSE = b"__closed__"
|
||||
@ -20,7 +20,7 @@ from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
|
||||
|
||||
class TestRedisBroadcastChannelIntegration:
|
||||
|
||||
@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from extensions import ext_redis
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
|
||||
|
||||
@ -18,13 +18,10 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import (
|
||||
from libs.broadcast_channel.redis.pubsub_channel import (
|
||||
BroadcastChannel as RedisBroadcastChannel,
|
||||
)
|
||||
from libs.broadcast_channel.redis.channel import (
|
||||
Topic,
|
||||
_RedisSubscription,
|
||||
)
|
||||
from libs.broadcast_channel.redis.pubsub_channel import Topic, _RedisSubscription
|
||||
from libs.broadcast_channel.redis.sharded_channel import (
|
||||
ShardedRedisBroadcastChannel,
|
||||
ShardedTopic,
|
||||
|
||||
@ -77,11 +77,28 @@ class FailExpireRedis(FakeStreamsRedis):
|
||||
|
||||
|
||||
class BlockingRedis:
|
||||
"""A Redis mock whose xread blocks until a control event is xadd-ed."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._release = threading.Event()
|
||||
self._store: dict[str, list[tuple[str, dict]]] = {}
|
||||
self._next_id: dict[str, int] = {}
|
||||
|
||||
def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str:
|
||||
n = self._next_id.get(key, 0) + 1
|
||||
self._next_id[key] = n
|
||||
entry_id = f"{n}-0"
|
||||
self._store.setdefault(key, []).append((entry_id, fields))
|
||||
self._release.set() # Wake up any blocked xread
|
||||
return entry_id
|
||||
|
||||
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
|
||||
self._release.wait(timeout=block / 1000.0 if block else None)
|
||||
key = next(iter(streams))
|
||||
entries = self._store.get(key, [])
|
||||
if entries:
|
||||
self._store[key] = [] # Consume entries
|
||||
return [(key, entries)]
|
||||
return []
|
||||
|
||||
def release(self) -> None:
|
||||
@ -176,48 +193,6 @@ class TestStreamsBroadcastChannel:
|
||||
assert topic.as_producer() is topic
|
||||
assert topic.as_subscriber() is topic
|
||||
|
||||
def test_join_timeout_ms_propagates_from_channel_to_subscription(self, fake_redis: FakeStreamsRedis):
|
||||
channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=150)
|
||||
topic = channel.topic("join-timeout-prop")
|
||||
|
||||
assert topic._join_timeout_ms == 150
|
||||
|
||||
sub = topic.subscribe()
|
||||
try:
|
||||
assert sub._join_timeout_ms == 150
|
||||
finally:
|
||||
sub.close()
|
||||
|
||||
def test_join_timeout_ms_defaults_to_2000(self, fake_redis: FakeStreamsRedis):
|
||||
channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60)
|
||||
topic = channel.topic("join-timeout-default")
|
||||
|
||||
assert topic._join_timeout_ms == 2000
|
||||
|
||||
def test_small_join_timeout_makes_close_return_promptly(self, fake_redis: FakeStreamsRedis):
|
||||
"""close() should respect the configured join timeout.
|
||||
|
||||
Regression test for SSE close tail latency: when an idle listener is
|
||||
blocked on its poll cycle, close() with a small join_timeout_ms must
|
||||
not wait for the full poll window. The orphaned daemon listener
|
||||
cleans itself up later.
|
||||
"""
|
||||
channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=50)
|
||||
topic = channel.topic("join-timeout-prompt-close")
|
||||
sub = topic.subscribe()
|
||||
|
||||
# Drive listener startup so the thread is actually blocked in xread.
|
||||
assert sub.receive(timeout=0.05) is None
|
||||
time.sleep(0.05)
|
||||
|
||||
started = time.monotonic()
|
||||
sub.close()
|
||||
elapsed = time.monotonic() - started
|
||||
|
||||
# 50ms timeout + scheduling slack; pick a ceiling well under the
|
||||
# default poll window (1000ms) to make the regression meaningful.
|
||||
assert elapsed < 0.5, f"close() took {elapsed:.3f}s; expected prompt return"
|
||||
|
||||
def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture):
|
||||
channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60)
|
||||
topic = channel.topic("expire-warning")
|
||||
@ -384,40 +359,32 @@ class TestStreamsSubscription:
|
||||
|
||||
assert next(iter(subscription)) == b"event"
|
||||
|
||||
def test_close_logs_debug_when_listener_does_not_stop_in_time(
|
||||
self,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
):
|
||||
"""When a low join_timeout elapses with the listener still alive,
|
||||
close() should log at DEBUG (not WARNING) - with a deliberately small
|
||||
timeout this is expected, not anomalous; the orphaned daemon thread
|
||||
cleans itself up on the next poll boundary.
|
||||
def test_control_event_unblocks_listener_for_prompt_close(self):
|
||||
"""close() returns promptly because the control event (xadd) unblocks
|
||||
the listener from its blocking xread call.
|
||||
"""
|
||||
import logging
|
||||
|
||||
blocking_redis = BlockingRedis()
|
||||
subscription = _StreamsSubscription(blocking_redis, "stream:slow-close")
|
||||
subscription = _StreamsSubscription(blocking_redis, "stream:prompt-close")
|
||||
|
||||
# Drive listener startup so the thread is blocked in xread.
|
||||
subscription._start_if_needed()
|
||||
listener = subscription._listener
|
||||
assert listener is not None
|
||||
assert listener.is_alive()
|
||||
|
||||
original_join = listener.join
|
||||
original_is_alive = listener.is_alive
|
||||
started = time.monotonic()
|
||||
subscription.close()
|
||||
elapsed = time.monotonic() - started
|
||||
|
||||
def delayed_join(timeout: float | None = None) -> None:
|
||||
original_join(0.01)
|
||||
# The control event (xadd) wakes up xread immediately, so close()
|
||||
# should return well under 1s (the xread BLOCK timeout).
|
||||
assert elapsed < 0.5, f"close() took {elapsed:.3f}s; expected prompt return via control event"
|
||||
|
||||
listener.join = delayed_join # type: ignore[method-assign]
|
||||
listener.is_alive = lambda: True # type: ignore[method-assign]
|
||||
def test_control_event_not_sent_when_listener_not_started(self):
|
||||
"""close() should not fail when the listener was never started."""
|
||||
subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:no-listener")
|
||||
subscription.close()
|
||||
|
||||
try:
|
||||
with caplog.at_level(logging.DEBUG, logger="libs.broadcast_channel.redis.streams_channel"):
|
||||
subscription.close()
|
||||
assert "did not stop within" in caplog.text
|
||||
assert "daemon thread will exit on its own" in caplog.text
|
||||
finally:
|
||||
listener.join = original_join # type: ignore[method-assign]
|
||||
listener.is_alive = original_is_alive # type: ignore[method-assign]
|
||||
blocking_redis.release()
|
||||
original_join(timeout=1)
|
||||
assert subscription._listener is None
|
||||
with pytest.raises(SubscriptionClosedError):
|
||||
subscription.receive(timeout=0.01)
|
||||
|
||||
@ -109,7 +109,7 @@ def _patch_get_channel_streams(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
@pytest.fixture
|
||||
def _patch_get_channel_pubsub(monkeypatch: pytest.MonkeyPatch):
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
|
||||
store: dict[str, deque[bytes]] = defaultdict(deque)
|
||||
client = _FakeRedisClient(store)
|
||||
|
||||
@ -120,7 +120,6 @@ CELERY_TASK_ANNOTATIONS=null
|
||||
EVENT_BUS_REDIS_URL=
|
||||
EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Web and app limits
|
||||
WEB_API_CORS_ALLOW_ORIGINS=*
|
||||
|
||||
Loading…
Reference in New Issue
Block a user