mirror of https://github.com/langgenius/dify.git
feat(api): Introduce Broadcast Channel (#27835)
This PR introduces a `BroadcastChannel` abstraction with broadcasting and at-most once delivery semantics, serving as the communication component between celery worker and API server. It also includes a reference implementation backed by Redis PubSub. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ed234e311b
commit
b9bc48d8dd
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Broadcast channel for Pub/Sub messaging.
|
||||
"""
|
||||
|
||||
import types
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Protocol, Self
|
||||
|
||||
|
||||
class Subscription(AbstractContextManager["Subscription"], Protocol):
|
||||
"""A subscription to a topic that provides an iterator over received messages.
|
||||
The subscription can be used as a context manager and will automatically
|
||||
close when exiting the context.
|
||||
|
||||
Note: `Subscription` instances are not thread-safe. Each thread should create its own
|
||||
subscription.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
"""`__iter__` returns an iterator used to consume the message from this subscription.
|
||||
|
||||
If the caller did not enter the context, `__iter__` may lazily perform the setup before
|
||||
yielding messages; otherwise `__enter__` handles it.”
|
||||
|
||||
If the subscription is closed, then the returned iterator exits without
|
||||
raising any error.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""close closes the subscription, releases any resources associated with it."""
|
||||
...
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
"""Receive the next message from the broadcast channel.
|
||||
|
||||
If `timeout` is specified, this method returns `None` if no message is
|
||||
received within the given period. If `timeout` is `None`, the call blocks
|
||||
until a message is received.
|
||||
|
||||
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
|
||||
cancel a blocking subscription.
|
||||
|
||||
:param timeout: timeout for receive message, in seconds.
|
||||
|
||||
Returns:
|
||||
bytes: The received message as a byte string, or
|
||||
None: If the timeout expires before a message is received.
|
||||
|
||||
Raises:
|
||||
SubscriptionClosed: If the subscription has already been closed.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Producer(Protocol):
|
||||
"""Producer is an interface for message publishing. It is already bound to a specific topic.
|
||||
|
||||
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, payload: bytes) -> None:
|
||||
"""Publish a message to the bounded topic."""
|
||||
...
|
||||
|
||||
|
||||
class Subscriber(Protocol):
|
||||
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
|
||||
|
||||
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def subscribe(self) -> Subscription:
|
||||
pass
|
||||
|
||||
|
||||
class Topic(Producer, Subscriber, Protocol):
|
||||
"""A named channel for publishing and subscribing to messages.
|
||||
|
||||
Topics provide both read and write access. For restricted access,
|
||||
use as_producer() for write-only view or as_subscriber() for read-only view.
|
||||
|
||||
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def as_producer(self) -> Producer:
|
||||
"""as_producer creates a write-only view for this topic."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
"""as_subscriber create a read-only view for this topic."""
|
||||
...
|
||||
|
||||
|
||||
class BroadcastChannel(Protocol):
|
||||
"""A broadcasting channel is a channel supporting broadcasting semantics.
|
||||
|
||||
Each channel is identified by a topic, different topics are isolated and do not affect each other.
|
||||
|
||||
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
|
||||
a specific topic, all subscription should receive the published message.
|
||||
|
||||
There are no restriction for the persistence of messages. Once a subscription is created, it
|
||||
should receive all subsequent messages published.
|
||||
|
||||
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def topic(self, topic: str) -> "Topic":
|
||||
"""topic returns a `Topic` instance for the given topic name."""
|
||||
...
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
class BroadcastChannelError(Exception):
|
||||
"""`BroadcastChannelError` is the base class for all exceptions related
|
||||
to `BroadcastChannel`."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubscriptionClosedError(BroadcastChannelError):
|
||||
"""SubscriptionClosedError means that the subscription has been closed and
|
||||
methods for consuming messages should not be called."""
|
||||
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .channel import BroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel"]
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
Redis Pub/Sub based broadcast channel implementation.
|
||||
|
||||
Provides "at most once" delivery semantics for messages published to channels.
|
||||
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||
|
||||
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis,
|
||||
):
|
||||
self._client = redis_client
|
||||
|
||||
def topic(self, topic: str) -> "Topic":
|
||||
return Topic(self._client, topic)
|
||||
|
||||
|
||||
class Topic:
|
||||
def __init__(self, redis_client: Redis, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.publish(self._topic, payload)
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisSubscription(
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
|
||||
|
||||
class _RedisSubscription(Subscription):
|
||||
def __init__(
|
||||
self,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
|
||||
self._dropped_count = 0
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
with self._start_lock:
|
||||
if self._started:
|
||||
return
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
if self._pubsub is None:
|
||||
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
|
||||
|
||||
self._pubsub.subscribe(self._topic)
|
||||
_logger.debug("Subscribed to channel %s", self._topic)
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-broadcast-{self._topic}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener_thread.start()
|
||||
self._started = True
|
||||
|
||||
def _listen(self) -> None:
|
||||
pubsub = self._pubsub
|
||||
assert pubsub is not None, "PubSub should not be None while starting listening."
|
||||
while not self._closed.is_set():
|
||||
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
if raw_message.get("type") != "message":
|
||||
continue
|
||||
|
||||
channel_field = raw_message.get("channel")
|
||||
if isinstance(channel_field, bytes):
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
elif isinstance(channel_field, str):
|
||||
channel_name = channel_field
|
||||
else:
|
||||
channel_name = str(channel_field)
|
||||
|
||||
if channel_name != self._topic:
|
||||
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
|
||||
continue
|
||||
|
||||
payload_bytes: bytes | None = raw_message.get("data")
|
||||
if not isinstance(payload_bytes, bytes):
|
||||
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
|
||||
_logger.debug("Listener thread stopped for channel %s", self._topic)
|
||||
pubsub.unsubscribe(self._topic)
|
||||
pubsub.close()
|
||||
_logger.debug("PubSub closed for topic %s", self._topic)
|
||||
self._pubsub = None
|
||||
|
||||
def _enqueue_message(self, payload: bytes) -> None:
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
self._queue.put_nowait(payload)
|
||||
return
|
||||
except queue.Full:
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
self._dropped_count += 1
|
||||
_logger.debug(
|
||||
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
|
||||
self._topic,
|
||||
self._dropped_count,
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
return
|
||||
|
||||
def _message_iterator(self) -> Generator[bytes, None, None]:
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
item = self._queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
self._start_if_needed()
|
||||
|
||||
try:
|
||||
item = self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
return item
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
|
||||
# method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=1.0)
|
||||
self._listener_thread = None
|
||||
|
|
@ -0,0 +1,311 @@
|
|||
"""
|
||||
Integration tests for Redis broadcast channel implementation using TestContainers.
|
||||
|
||||
This test suite covers real Redis interactions including:
|
||||
- Multiple producer/consumer scenarios
|
||||
- Network failure scenarios
|
||||
- Performance under load
|
||||
- Real-world usage patterns
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
|
||||
|
||||
class TestRedisBroadcastChannelIntegration:
|
||||
"""Integration tests for Redis broadcast channel with real Redis instance."""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_container(self) -> Iterator[RedisContainer]:
|
||||
"""Create a Redis container for integration testing."""
|
||||
with RedisContainer(image="redis:6-alpine") as container:
|
||||
yield container
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
|
||||
"""Create a Redis client connected to the test container."""
|
||||
host = redis_container.get_container_host_ip()
|
||||
port = redis_container.get_exposed_port(6379)
|
||||
return redis.Redis(host=host, port=port, decode_responses=False)
|
||||
|
||||
@pytest.fixture
|
||||
def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
|
||||
"""Create a BroadcastChannel instance with real Redis client."""
|
||||
return RedisBroadcastChannel(redis_client)
|
||||
|
||||
@classmethod
|
||||
def _get_test_topic_name(cls):
|
||||
return f"test_topic_{uuid.uuid4()}"
|
||||
|
||||
# ==================== Basic Functionality Tests ===================='
|
||||
|
||||
def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel):
|
||||
topic_name = self._get_test_topic_name()
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
subscription = topic.subscribe()
|
||||
consuming_event = threading.Event()
|
||||
|
||||
def consume():
|
||||
msgs = []
|
||||
consuming_event.set()
|
||||
for msg in subscription:
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
producer_future = executor.submit(consume)
|
||||
consuming_event.wait()
|
||||
subscription.close()
|
||||
msgs = producer_future.result(timeout=1)
|
||||
assert msgs == []
|
||||
|
||||
def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test complete end-to-end messaging flow."""
|
||||
topic_name = "test-topic"
|
||||
message = b"hello world"
|
||||
|
||||
# Create producer and subscriber
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
producer = topic.as_producer()
|
||||
subscription = topic.subscribe()
|
||||
|
||||
# Publish and receive message
|
||||
|
||||
def producer_thread():
|
||||
time.sleep(0.1) # Small delay to ensure subscriber is ready
|
||||
producer.publish(message)
|
||||
time.sleep(0.1)
|
||||
subscription.close()
|
||||
|
||||
def consumer_thread() -> list[bytes]:
|
||||
received_messages = []
|
||||
for msg in subscription:
|
||||
received_messages.append(msg)
|
||||
return received_messages
|
||||
|
||||
# Run producer and consumer
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
producer_future = executor.submit(producer_thread)
|
||||
consumer_future = executor.submit(consumer_thread)
|
||||
|
||||
# Wait for completion
|
||||
producer_future.result(timeout=5.0)
|
||||
received_messages = consumer_future.result(timeout=5.0)
|
||||
|
||||
assert len(received_messages) == 1
|
||||
assert received_messages[0] == message
|
||||
|
||||
def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test message broadcasting to multiple subscribers."""
|
||||
topic_name = "broadcast-topic"
|
||||
message = b"broadcast message"
|
||||
subscriber_count = 5
|
||||
|
||||
# Create producer and multiple subscribers
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
producer = topic.as_producer()
|
||||
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
|
||||
|
||||
def producer_thread():
|
||||
time.sleep(0.2) # Allow all subscribers to connect
|
||||
producer.publish(message)
|
||||
time.sleep(0.2)
|
||||
for sub in subscriptions:
|
||||
sub.close()
|
||||
|
||||
def consumer_thread(subscription: Subscription) -> list[bytes]:
|
||||
received_msgs = []
|
||||
while True:
|
||||
try:
|
||||
msg = subscription.receive(0.1)
|
||||
except SubscriptionClosedError:
|
||||
break
|
||||
if msg is None:
|
||||
continue
|
||||
received_msgs.append(msg)
|
||||
if len(received_msgs) >= 1:
|
||||
break
|
||||
return received_msgs
|
||||
|
||||
# Run producer and consumers
|
||||
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
|
||||
producer_future = executor.submit(producer_thread)
|
||||
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
|
||||
|
||||
# Wait for completion
|
||||
producer_future.result(timeout=10.0)
|
||||
msgs_by_consumers = []
|
||||
for future in as_completed(consumer_futures, timeout=10.0):
|
||||
msgs_by_consumers.append(future.result())
|
||||
|
||||
# Close all subscriptions
|
||||
for subscription in subscriptions:
|
||||
subscription.close()
|
||||
|
||||
# Verify all subscribers received the message
|
||||
for msgs in msgs_by_consumers:
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0] == message
|
||||
|
||||
def test_topic_isolation(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test that different topics are isolated from each other."""
|
||||
topic1_name = "topic1"
|
||||
topic2_name = "topic2"
|
||||
message1 = b"message for topic1"
|
||||
message2 = b"message for topic2"
|
||||
|
||||
# Create producers and subscribers for different topics
|
||||
topic1 = broadcast_channel.topic(topic1_name)
|
||||
topic2 = broadcast_channel.topic(topic2_name)
|
||||
|
||||
def producer_thread():
|
||||
time.sleep(0.1)
|
||||
topic1.publish(message1)
|
||||
topic2.publish(message2)
|
||||
|
||||
def consumer_by_thread(topic: Topic) -> list[bytes]:
|
||||
subscription = topic.subscribe()
|
||||
received = []
|
||||
with subscription:
|
||||
for msg in subscription:
|
||||
received.append(msg)
|
||||
if len(received) >= 1:
|
||||
break
|
||||
return received
|
||||
|
||||
# Run all threads
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
producer_future = executor.submit(producer_thread)
|
||||
consumer1_future = executor.submit(consumer_by_thread, topic1)
|
||||
consumer2_future = executor.submit(consumer_by_thread, topic2)
|
||||
|
||||
# Wait for completion
|
||||
producer_future.result(timeout=5.0)
|
||||
received_by_topic1 = consumer1_future.result(timeout=5.0)
|
||||
received_by_topic2 = consumer2_future.result(timeout=5.0)
|
||||
|
||||
# Verify topic isolation
|
||||
assert len(received_by_topic1) == 1
|
||||
assert len(received_by_topic2) == 1
|
||||
assert received_by_topic1[0] == message1
|
||||
assert received_by_topic2[0] == message2
|
||||
|
||||
# ==================== Performance Tests ====================
|
||||
|
||||
def test_concurrent_producers(self, broadcast_channel: BroadcastChannel):
|
||||
"""Test multiple producers publishing to the same topic."""
|
||||
topic_name = "concurrent-producers-topic"
|
||||
producer_count = 5
|
||||
messages_per_producer = 5
|
||||
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
subscription = topic.subscribe()
|
||||
|
||||
expected_total = producer_count * messages_per_producer
|
||||
consumer_ready = threading.Event()
|
||||
|
||||
def producer_thread(producer_idx: int) -> set[bytes]:
|
||||
producer = topic.as_producer()
|
||||
produced = set()
|
||||
for i in range(messages_per_producer):
|
||||
message = f"producer_{producer_idx}_msg_{i}".encode()
|
||||
produced.add(message)
|
||||
producer.publish(message)
|
||||
time.sleep(0.001) # Small delay to avoid overwhelming
|
||||
return produced
|
||||
|
||||
def consumer_thread() -> set[bytes]:
|
||||
received_msgs: set[bytes] = set()
|
||||
with subscription:
|
||||
consumer_ready.set()
|
||||
while True:
|
||||
try:
|
||||
msg = subscription.receive(timeout=0.1)
|
||||
except SubscriptionClosedError:
|
||||
break
|
||||
if msg is None:
|
||||
if len(received_msgs) >= expected_total:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
received_msgs.add(msg)
|
||||
return received_msgs
|
||||
|
||||
# Run producers and consumer
|
||||
with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
|
||||
consumer_future = executor.submit(consumer_thread)
|
||||
consumer_ready.wait()
|
||||
producer_futures = [executor.submit(producer_thread, i) for i in range(producer_count)]
|
||||
|
||||
sent_msgs: set[bytes] = set()
|
||||
# Wait for completion
|
||||
for future in as_completed(producer_futures, timeout=30.0):
|
||||
sent_msgs.update(future.result())
|
||||
|
||||
subscription.close()
|
||||
consumer_received_msgs = consumer_future.result(timeout=30.0)
|
||||
|
||||
# Verify message content
|
||||
assert sent_msgs == consumer_received_msgs
|
||||
|
||||
# ==================== Resource Management Tests ====================
|
||||
|
||||
def test_subscription_cleanup(self, broadcast_channel: BroadcastChannel, redis_client: redis.Redis):
|
||||
"""Test proper cleanup of subscription resources."""
|
||||
topic_name = "cleanup-test-topic"
|
||||
|
||||
# Create multiple subscriptions
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
|
||||
def _consume(sub: Subscription):
|
||||
for i in sub:
|
||||
pass
|
||||
|
||||
subscriptions = []
|
||||
for i in range(5):
|
||||
subscription = topic.subscribe()
|
||||
subscriptions.append(subscription)
|
||||
|
||||
# Start all subscriptions
|
||||
thread = threading.Thread(target=_consume, args=(subscription,))
|
||||
thread.start()
|
||||
time.sleep(0.01)
|
||||
|
||||
# Verify subscriptions are active
|
||||
pubsub_info = redis_client.pubsub_numsub(topic_name)
|
||||
# pubsub_numsub returns list of tuples, find our topic
|
||||
topic_subscribers = 0
|
||||
for channel, count in pubsub_info:
|
||||
# the channel name returned by redis is bytes.
|
||||
if channel == topic_name.encode():
|
||||
topic_subscribers = count
|
||||
break
|
||||
assert topic_subscribers >= 5
|
||||
|
||||
# Close all subscriptions
|
||||
for subscription in subscriptions:
|
||||
subscription.close()
|
||||
|
||||
# Wait a bit for cleanup
|
||||
time.sleep(1)
|
||||
|
||||
# Verify subscriptions are cleaned up
|
||||
pubsub_info_after = redis_client.pubsub_numsub(topic_name)
|
||||
topic_subscribers_after = 0
|
||||
for channel, count in pubsub_info_after:
|
||||
if channel == topic_name.encode():
|
||||
topic_subscribers_after = count
|
||||
break
|
||||
assert topic_subscribers_after == 0
|
||||
|
|
@ -0,0 +1,514 @@
|
|||
"""
|
||||
Comprehensive unit tests for Redis broadcast channel implementation.
|
||||
|
||||
This test suite covers all aspects of the Redis broadcast channel including:
|
||||
- Basic functionality and contract compliance
|
||||
- Error handling and edge cases
|
||||
- Thread safety and concurrency
|
||||
- Resource management and cleanup
|
||||
- Performance and reliability scenarios
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import (
|
||||
BroadcastChannel as RedisBroadcastChannel,
|
||||
)
|
||||
from libs.broadcast_channel.redis.channel import (
|
||||
Topic,
|
||||
_RedisSubscription,
|
||||
)
|
||||
|
||||
|
||||
class TestBroadcastChannel:
|
||||
"""Test cases for the main BroadcastChannel class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
"""Create a mock Redis client for testing."""
|
||||
client = MagicMock()
|
||||
client.pubsub.return_value = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel:
|
||||
"""Create a BroadcastChannel instance with mock Redis client."""
|
||||
return RedisBroadcastChannel(mock_redis_client)
|
||||
|
||||
def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock):
|
||||
"""Test that topic() method returns a Topic instance with correct parameters."""
|
||||
topic_name = "test-topic"
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
|
||||
assert isinstance(topic, Topic)
|
||||
assert topic._client == mock_redis_client
|
||||
assert topic._topic == topic_name
|
||||
|
||||
def test_topic_isolation(self, broadcast_channel: RedisBroadcastChannel):
|
||||
"""Test that different topic names create isolated Topic instances."""
|
||||
topic1 = broadcast_channel.topic("topic1")
|
||||
topic2 = broadcast_channel.topic("topic2")
|
||||
|
||||
assert topic1 is not topic2
|
||||
assert topic1._topic == "topic1"
|
||||
assert topic2._topic == "topic2"
|
||||
|
||||
|
||||
class TestTopic:
|
||||
"""Test cases for the Topic class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
"""Create a mock Redis client for testing."""
|
||||
client = MagicMock()
|
||||
client.pubsub.return_value = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def topic(self, mock_redis_client: MagicMock) -> Topic:
|
||||
"""Create a Topic instance for testing."""
|
||||
return Topic(mock_redis_client, "test-topic")
|
||||
|
||||
def test_as_producer_returns_self(self, topic: Topic):
|
||||
"""Test that as_producer() returns self as Producer interface."""
|
||||
producer = topic.as_producer()
|
||||
assert producer is topic
|
||||
# Producer is a Protocol, check duck typing instead
|
||||
assert hasattr(producer, "publish")
|
||||
|
||||
def test_as_subscriber_returns_self(self, topic: Topic):
|
||||
"""Test that as_subscriber() returns self as Subscriber interface."""
|
||||
subscriber = topic.as_subscriber()
|
||||
assert subscriber is topic
|
||||
# Subscriber is a Protocol, check duck typing instead
|
||||
assert hasattr(subscriber, "subscribe")
|
||||
|
||||
def test_publish_calls_redis_publish(self, topic: Topic, mock_redis_client: MagicMock):
|
||||
"""Test that publish() calls Redis PUBLISH with correct parameters."""
|
||||
payload = b"test message"
|
||||
topic.publish(payload)
|
||||
|
||||
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SubscriptionTestCase:
|
||||
"""Test case data for subscription tests."""
|
||||
|
||||
name: str
|
||||
buffer_size: int
|
||||
payload: bytes
|
||||
expected_messages: list[bytes]
|
||||
should_drop: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
class TestRedisSubscription:
|
||||
"""Test cases for the _RedisSubscription class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
pubsub = MagicMock()
|
||||
pubsub.subscribe = MagicMock()
|
||||
pubsub.unsubscribe = MagicMock()
|
||||
pubsub.close = MagicMock()
|
||||
pubsub.get_message = MagicMock()
|
||||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
|
||||
"""Create a _RedisSubscription instance for testing."""
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
yield subscription
|
||||
subscription.close()
|
||||
|
||||
@pytest.fixture
|
||||
def started_subscription(self, subscription: _RedisSubscription) -> _RedisSubscription:
|
||||
"""Create a subscription that has been started."""
|
||||
subscription._start_if_needed()
|
||||
return subscription
|
||||
|
||||
# ==================== Lifecycle Tests ====================
|
||||
|
||||
def test_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||
"""Test that subscription is properly initialized."""
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
||||
assert subscription._pubsub is mock_pubsub
|
||||
assert subscription._topic == "test-topic"
|
||||
assert not subscription._closed.is_set()
|
||||
assert subscription._dropped_count == 0
|
||||
assert subscription._listener_thread is None
|
||||
assert not subscription._started
|
||||
|
||||
def test_start_if_needed_first_call(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that _start_if_needed() properly starts subscription on first call."""
|
||||
subscription._start_if_needed()
|
||||
|
||||
mock_pubsub.subscribe.assert_called_once_with("test-topic")
|
||||
assert subscription._started is True
|
||||
assert subscription._listener_thread is not None
|
||||
|
||||
def test_start_if_needed_subsequent_calls(self, started_subscription: _RedisSubscription):
|
||||
"""Test that _start_if_needed() doesn't start subscription on subsequent calls."""
|
||||
original_thread = started_subscription._listener_thread
|
||||
started_subscription._start_if_needed()
|
||||
|
||||
# Should not create new thread or generator
|
||||
assert started_subscription._listener_thread is original_thread
|
||||
|
||||
def test_start_if_needed_when_closed(self, subscription: _RedisSubscription):
|
||||
"""Test that _start_if_needed() raises error when subscription is closed."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription):
|
||||
"""Test that _start_if_needed() raises error when pubsub is None."""
|
||||
subscription._pubsub = None
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that subscription works as context manager."""
|
||||
with subscription as sub:
|
||||
assert sub is subscription
|
||||
assert subscription._started is True
|
||||
mock_pubsub.subscribe.assert_called_once_with("test-topic")
|
||||
|
||||
def test_close_idempotent(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that close() is idempotent and can be called multiple times."""
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Close multiple times
|
||||
subscription.close()
|
||||
subscription.close()
|
||||
subscription.close()
|
||||
|
||||
# Should only cleanup once
|
||||
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
assert subscription._pubsub is None
|
||||
assert subscription._closed.is_set()
|
||||
|
||||
def test_close_cleanup(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that close() properly cleans up all resources."""
|
||||
subscription._start_if_needed()
|
||||
thread = subscription._listener_thread
|
||||
|
||||
subscription.close()
|
||||
|
||||
# Verify cleanup
|
||||
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
assert subscription._pubsub is None
|
||||
assert subscription._listener_thread is None
|
||||
|
||||
# Wait for thread to finish (with timeout)
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=1.0)
|
||||
assert not thread.is_alive()
|
||||
|
||||
# ==================== Message Processing Tests ====================
|
||||
|
||||
def test_message_iterator_with_messages(self, started_subscription: _RedisSubscription):
|
||||
"""Test message iterator behavior with messages in queue."""
|
||||
test_messages = [b"msg1", b"msg2", b"msg3"]
|
||||
|
||||
# Add messages to queue
|
||||
for msg in test_messages:
|
||||
started_subscription._queue.put_nowait(msg)
|
||||
|
||||
# Iterate through messages
|
||||
iterator = iter(started_subscription)
|
||||
received_messages = []
|
||||
|
||||
for msg in iterator:
|
||||
received_messages.append(msg)
|
||||
if len(received_messages) >= len(test_messages):
|
||||
break
|
||||
|
||||
assert received_messages == test_messages
|
||||
|
||||
def test_message_iterator_when_closed(self, subscription: _RedisSubscription):
|
||||
"""Test that iterator raises error when subscription is closed."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"):
|
||||
iter(subscription)
|
||||
|
||||
# ==================== Message Enqueue Tests ====================
|
||||
|
||||
def test_enqueue_message_success(self, started_subscription: _RedisSubscription):
|
||||
"""Test successful message enqueue."""
|
||||
payload = b"test message"
|
||||
|
||||
started_subscription._enqueue_message(payload)
|
||||
|
||||
assert started_subscription._queue.qsize() == 1
|
||||
assert started_subscription._queue.get_nowait() == payload
|
||||
|
||||
def test_enqueue_message_when_closed(self, subscription: _RedisSubscription):
|
||||
"""Test message enqueue when subscription is closed."""
|
||||
subscription.close()
|
||||
payload = b"test message"
|
||||
|
||||
# Should not raise exception, but should not enqueue
|
||||
subscription._enqueue_message(payload)
|
||||
|
||||
assert subscription._queue.empty()
|
||||
|
||||
def test_enqueue_message_with_full_queue(self, started_subscription: _RedisSubscription):
|
||||
"""Test message enqueue with full queue (dropping behavior)."""
|
||||
# Fill the queue
|
||||
for i in range(started_subscription._queue.maxsize):
|
||||
started_subscription._queue.put_nowait(f"old_msg_{i}".encode())
|
||||
|
||||
# Try to enqueue new message (should drop oldest)
|
||||
new_message = b"new_message"
|
||||
started_subscription._enqueue_message(new_message)
|
||||
|
||||
# Should have dropped one message and added new one
|
||||
assert started_subscription._dropped_count == 1
|
||||
|
||||
# New message should be in queue
|
||||
messages = []
|
||||
while not started_subscription._queue.empty():
|
||||
messages.append(started_subscription._queue.get_nowait())
|
||||
|
||||
assert new_message in messages
|
||||
|
||||
# ==================== Listener Thread Tests ====================
|
||||
|
||||
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
|
||||
def test_listener_thread_normal_operation(
|
||||
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
|
||||
):
|
||||
"""Test listener thread normal operation."""
|
||||
# Mock message from Redis
|
||||
mock_message = {"type": "message", "channel": "test-topic", "data": b"test payload"}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
# Start listener
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Wait a bit for processing
|
||||
time.sleep(0.1)
|
||||
|
||||
# Verify message was processed
|
||||
assert not subscription._queue.empty()
|
||||
assert subscription._queue.get_nowait() == b"test payload"
|
||||
|
||||
def test_listener_thread_ignores_subscribe_messages(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread ignores subscribe/unsubscribe messages."""
|
||||
mock_message = {"type": "subscribe", "channel": "test-topic", "data": 1}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
subscription._start_if_needed()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Should not enqueue subscribe messages
|
||||
assert subscription._queue.empty()
|
||||
|
||||
def test_listener_thread_ignores_wrong_channel(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread ignores messages from wrong channels."""
|
||||
mock_message = {"type": "message", "channel": "wrong-topic", "data": b"test payload"}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
subscription._start_if_needed()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Should not enqueue messages from wrong channels
|
||||
assert subscription._queue.empty()
|
||||
|
||||
def test_listener_thread_handles_redis_exceptions(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread handles Redis exceptions gracefully."""
|
||||
mock_pubsub.get_message.side_effect = Exception("Redis error")
|
||||
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Wait for thread to handle exception
|
||||
time.sleep(0.2)
|
||||
|
||||
# Thread should still be alive but not processing
|
||||
assert subscription._listener_thread is not None
|
||||
assert not subscription._listener_thread.is_alive()
|
||||
|
||||
def test_listener_thread_stops_when_closed(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
|
||||
"""Test that listener thread stops when subscription is closed."""
|
||||
subscription._start_if_needed()
|
||||
thread = subscription._listener_thread
|
||||
|
||||
# Close subscription
|
||||
subscription.close()
|
||||
|
||||
# Wait for thread to finish
|
||||
if thread is not None and thread.is_alive():
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
assert thread is None or not thread.is_alive()
|
||||
|
||||
# ==================== Table-driven Tests ====================
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
SubscriptionTestCase(
|
||||
name="basic_message",
|
||||
buffer_size=5,
|
||||
payload=b"hello world",
|
||||
expected_messages=[b"hello world"],
|
||||
description="Basic message publishing and receiving",
|
||||
),
|
||||
SubscriptionTestCase(
|
||||
name="empty_message",
|
||||
buffer_size=5,
|
||||
payload=b"",
|
||||
expected_messages=[b""],
|
||||
description="Empty message handling",
|
||||
),
|
||||
SubscriptionTestCase(
|
||||
name="large_message",
|
||||
buffer_size=5,
|
||||
payload=b"x" * 10000,
|
||||
expected_messages=[b"x" * 10000],
|
||||
description="Large message handling",
|
||||
),
|
||||
SubscriptionTestCase(
|
||||
name="unicode_message",
|
||||
buffer_size=5,
|
||||
payload="你好世界".encode(),
|
||||
expected_messages=["你好世界".encode()],
|
||||
description="Unicode message handling",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||
"""Test various subscription scenarios using table-driven approach."""
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
||||
# Simulate receiving message
|
||||
mock_message = {"type": "message", "channel": "test-topic", "data": test_case.payload}
|
||||
mock_pubsub.get_message.return_value = mock_message
|
||||
|
||||
try:
|
||||
with subscription:
|
||||
# Wait for message processing
|
||||
time.sleep(0.1)
|
||||
|
||||
# Collect received messages
|
||||
received = []
|
||||
for msg in subscription:
|
||||
received.append(msg)
|
||||
if len(received) >= len(test_case.expected_messages):
|
||||
break
|
||||
|
||||
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
def test_concurrent_close_and_enqueue(self, started_subscription: _RedisSubscription):
|
||||
"""Test concurrent close and enqueue operations."""
|
||||
errors = []
|
||||
|
||||
def close_subscription():
|
||||
try:
|
||||
time.sleep(0.05) # Small delay
|
||||
started_subscription.close()
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def enqueue_messages():
|
||||
try:
|
||||
for i in range(50):
|
||||
started_subscription._enqueue_message(f"msg_{i}".encode())
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
# Start threads
|
||||
close_thread = threading.Thread(target=close_subscription)
|
||||
enqueue_thread = threading.Thread(target=enqueue_messages)
|
||||
|
||||
close_thread.start()
|
||||
enqueue_thread.start()
|
||||
|
||||
# Wait for completion
|
||||
close_thread.join(timeout=2.0)
|
||||
enqueue_thread.join(timeout=2.0)
|
||||
|
||||
# Should not have any errors (operations should be safe)
|
||||
assert len(errors) == 0
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_iterator_after_close(self, subscription: _RedisSubscription):
|
||||
"""Test iterator behavior after close."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||
iter(subscription)
|
||||
|
||||
def test_start_after_close(self, subscription: _RedisSubscription):
|
||||
"""Test start attempts after close."""
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
def test_pubsub_none_operations(self, subscription: _RedisSubscription):
|
||||
"""Test operations when pubsub is None."""
|
||||
subscription._pubsub = None
|
||||
|
||||
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
|
||||
subscription._start_if_needed()
|
||||
|
||||
# Close should still work
|
||||
subscription.close() # Should not raise
|
||||
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||
"""Test various channel name formats."""
|
||||
channel_names = [
|
||||
"simple",
|
||||
"with-dashes",
|
||||
"with_underscores",
|
||||
"with.numbers",
|
||||
"WITH.UPPERCASE",
|
||||
"mixed-CASE_name",
|
||||
"very.long.channel.name.with.multiple.parts",
|
||||
]
|
||||
|
||||
for channel_name in channel_names:
|
||||
subscription = _RedisSubscription(
|
||||
pubsub=mock_pubsub,
|
||||
topic=channel_name,
|
||||
)
|
||||
|
||||
subscription._start_if_needed()
|
||||
mock_pubsub.subscribe.assert_called_with(channel_name)
|
||||
subscription.close()
|
||||
|
||||
def test_received_on_closed_subscription(self, subscription: _RedisSubscription):
|
||||
subscription.close()
|
||||
|
||||
with pytest.raises(SubscriptionClosedError):
|
||||
subscription.receive()
|
||||
Loading…
Reference in New Issue