From b9bc48d8dd4251ee8782d4c9f5b37555f2665c01 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 10 Nov 2025 17:23:21 +0800 Subject: [PATCH] feat(api): Introduce Broadcast Channel (#27835) This PR introduces a `BroadcastChannel` abstraction with broadcasting and at-most once delivery semantics, serving as the communication component between celery worker and API server. It also includes a reference implementation backed by Redis PubSub. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/libs/broadcast_channel/channel.py | 134 +++++ api/libs/broadcast_channel/exc.py | 12 + api/libs/broadcast_channel/redis/__init__.py | 3 + api/libs/broadcast_channel/redis/channel.py | 200 +++++++ .../broadcast_channel/redis/test_channel.py | 311 +++++++++++ .../redis/test_channel_unit_tests.py | 514 ++++++++++++++++++ 6 files changed, 1174 insertions(+) create mode 100644 api/libs/broadcast_channel/channel.py create mode 100644 api/libs/broadcast_channel/exc.py create mode 100644 api/libs/broadcast_channel/redis/__init__.py create mode 100644 api/libs/broadcast_channel/redis/channel.py create mode 100644 api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_channel.py create mode 100644 api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py new file mode 100644 index 0000000000..5bbf0c79a3 --- /dev/null +++ b/api/libs/broadcast_channel/channel.py @@ -0,0 +1,134 @@ +""" +Broadcast channel for Pub/Sub messaging. +""" + +import types +from abc import abstractmethod +from collections.abc import Iterator +from contextlib import AbstractContextManager +from typing import Protocol, Self + + +class Subscription(AbstractContextManager["Subscription"], Protocol): + """A subscription to a topic that provides an iterator over received messages. + The subscription can be used as a context manager and will automatically + close when exiting the context. + + Note: `Subscription` instances are not thread-safe. Each thread should create its own + subscription. + """ + + @abstractmethod + def __iter__(self) -> Iterator[bytes]: + """`__iter__` returns an iterator used to consume the message from this subscription. + + If the caller did not enter the context, `__iter__` may lazily perform the setup before + yielding messages; otherwise `__enter__` handles it.” + + If the subscription is closed, then the returned iterator exits without + raising any error. + """ + ... + + @abstractmethod + def close(self) -> None: + """close closes the subscription, releases any resources associated with it.""" + ... + + def __enter__(self) -> Self: + """`__enter__` does the setup logic of the subscription (if any), and return itself.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: + self.close() + return None + + @abstractmethod + def receive(self, timeout: float | None = 0.1) -> bytes | None: + """Receive the next message from the broadcast channel. + + If `timeout` is specified, this method returns `None` if no message is + received within the given period. If `timeout` is `None`, the call blocks + until a message is received. + + Calling receive with `timeout=None` is highly discouraged, as it is impossible to + cancel a blocking subscription. + + :param timeout: timeout for receive message, in seconds. + + Returns: + bytes: The received message as a byte string, or + None: If the timeout expires before a message is received. + + Raises: + SubscriptionClosed: If the subscription has already been closed. + """ + ... + + +class Producer(Protocol): + """Producer is an interface for message publishing. It is already bound to a specific topic. + + `Producer` implementations must be thread-safe and support concurrent use by multiple threads. + """ + + @abstractmethod + def publish(self, payload: bytes) -> None: + """Publish a message to the bounded topic.""" + ... + + +class Subscriber(Protocol): + """Subscriber is an interface for subscription creation. It is already bound to a specific topic. + + `Subscriber` implementations must be thread-safe and support concurrent use by multiple threads. + """ + + @abstractmethod + def subscribe(self) -> Subscription: + pass + + +class Topic(Producer, Subscriber, Protocol): + """A named channel for publishing and subscribing to messages. + + Topics provide both read and write access. For restricted access, + use as_producer() for write-only view or as_subscriber() for read-only view. + + `Topic` implementations must be thread-safe and support concurrent use by multiple threads. + """ + + @abstractmethod + def as_producer(self) -> Producer: + """as_producer creates a write-only view for this topic.""" + ... + + @abstractmethod + def as_subscriber(self) -> Subscriber: + """as_subscriber create a read-only view for this topic.""" + ... + + +class BroadcastChannel(Protocol): + """A broadcasting channel is a channel supporting broadcasting semantics. + + Each channel is identified by a topic, different topics are isolated and do not affect each other. + + There can be multiple subscriptions to a specific topic. When a publisher publishes a message to + a specific topic, all subscription should receive the published message. + + There are no restriction for the persistence of messages. Once a subscription is created, it + should receive all subsequent messages published. + + `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads. + """ + + @abstractmethod + def topic(self, topic: str) -> "Topic": + """topic returns a `Topic` instance for the given topic name.""" + ... diff --git a/api/libs/broadcast_channel/exc.py b/api/libs/broadcast_channel/exc.py new file mode 100644 index 0000000000..ab958c94ed --- /dev/null +++ b/api/libs/broadcast_channel/exc.py @@ -0,0 +1,12 @@ +class BroadcastChannelError(Exception): + """`BroadcastChannelError` is the base class for all exceptions related + to `BroadcastChannel`.""" + + pass + + +class SubscriptionClosedError(BroadcastChannelError): + """SubscriptionClosedError means that the subscription has been closed and + methods for consuming messages should not be called.""" + + pass diff --git a/api/libs/broadcast_channel/redis/__init__.py b/api/libs/broadcast_channel/redis/__init__.py new file mode 100644 index 0000000000..138fef5c5f --- /dev/null +++ b/api/libs/broadcast_channel/redis/__init__.py @@ -0,0 +1,3 @@ +from .channel import BroadcastChannel + +__all__ = ["BroadcastChannel"] diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py new file mode 100644 index 0000000000..e6b32345be --- /dev/null +++ b/api/libs/broadcast_channel/redis/channel.py @@ -0,0 +1,200 @@ +import logging +import queue +import threading +import types +from collections.abc import Generator, Iterator +from typing import Self + +from libs.broadcast_channel.channel import Producer, Subscriber, Subscription +from libs.broadcast_channel.exc import SubscriptionClosedError +from redis import Redis +from redis.client import PubSub + +_logger = logging.getLogger(__name__) + + +class BroadcastChannel: + """ + Redis Pub/Sub based broadcast channel implementation. + + Provides "at most once" delivery semantics for messages published to channels. + Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery. + + The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`. + """ + + def __init__( + self, + redis_client: Redis, + ): + self._client = redis_client + + def topic(self, topic: str) -> "Topic": + return Topic(self._client, topic) + + +class Topic: + def __init__(self, redis_client: Redis, topic: str): + self._client = redis_client + self._topic = topic + + def as_producer(self) -> Producer: + return self + + def publish(self, payload: bytes) -> None: + self._client.publish(self._topic, payload) + + def as_subscriber(self) -> Subscriber: + return self + + def subscribe(self) -> Subscription: + return _RedisSubscription( + pubsub=self._client.pubsub(), + topic=self._topic, + ) + + +class _RedisSubscription(Subscription): + def __init__( + self, + pubsub: PubSub, + topic: str, + ): + # The _pubsub is None only if the subscription is closed. + self._pubsub: PubSub | None = pubsub + self._topic = topic + self._closed = threading.Event() + self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024) + self._dropped_count = 0 + self._listener_thread: threading.Thread | None = None + self._start_lock = threading.Lock() + self._started = False + + def _start_if_needed(self) -> None: + with self._start_lock: + if self._started: + return + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis subscription is closed") + if self._pubsub is None: + raise SubscriptionClosedError("The Redis subscription has been cleaned up") + + self._pubsub.subscribe(self._topic) + _logger.debug("Subscribed to channel %s", self._topic) + + self._listener_thread = threading.Thread( + target=self._listen, + name=f"redis-broadcast-{self._topic}", + daemon=True, + ) + self._listener_thread.start() + self._started = True + + def _listen(self) -> None: + pubsub = self._pubsub + assert pubsub is not None, "PubSub should not be None while starting listening." + while not self._closed.is_set(): + raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1) + + if raw_message is None: + continue + + if raw_message.get("type") != "message": + continue + + channel_field = raw_message.get("channel") + if isinstance(channel_field, bytes): + channel_name = channel_field.decode("utf-8") + elif isinstance(channel_field, str): + channel_name = channel_field + else: + channel_name = str(channel_field) + + if channel_name != self._topic: + _logger.warning("Ignoring message from unexpected channel %s", channel_name) + continue + + payload_bytes: bytes | None = raw_message.get("data") + if not isinstance(payload_bytes, bytes): + _logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes)) + continue + + self._enqueue_message(payload_bytes) + + _logger.debug("Listener thread stopped for channel %s", self._topic) + pubsub.unsubscribe(self._topic) + pubsub.close() + _logger.debug("PubSub closed for topic %s", self._topic) + self._pubsub = None + + def _enqueue_message(self, payload: bytes) -> None: + while not self._closed.is_set(): + try: + self._queue.put_nowait(payload) + return + except queue.Full: + try: + self._queue.get_nowait() + self._dropped_count += 1 + _logger.debug( + "Dropped message from Redis subscription, topic=%s, total_dropped=%d", + self._topic, + self._dropped_count, + ) + except queue.Empty: + continue + return + + def _message_iterator(self) -> Generator[bytes, None, None]: + while not self._closed.is_set(): + try: + item = self._queue.get(timeout=0.1) + except queue.Empty: + continue + + yield item + + def __iter__(self) -> Iterator[bytes]: + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis subscription is closed") + self._start_if_needed() + return iter(self._message_iterator()) + + def receive(self, timeout: float | None = None) -> bytes | None: + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis subscription is closed") + self._start_if_needed() + + try: + item = self._queue.get(timeout=timeout) + except queue.Empty: + return None + + return item + + def __enter__(self) -> Self: + self._start_if_needed() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: + self.close() + return None + + def close(self) -> None: + if self._closed.is_set(): + return + + self._closed.set() + # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message` + # method should NOT be called concurrently. + # + # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread. + listener = self._listener_thread + if listener is not None: + listener.join(timeout=1.0) + self._listener_thread = None diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_channel.py new file mode 100644 index 0000000000..c2e17328d6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_channel.py @@ -0,0 +1,311 @@ +""" +Integration tests for Redis broadcast channel implementation using TestContainers. + +This test suite covers real Redis interactions including: +- Multiple producer/consumer scenarios +- Network failure scenarios +- Performance under load +- Real-world usage patterns +""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel + + +class TestRedisBroadcastChannelIntegration: + """Integration tests for Redis broadcast channel with real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a BroadcastChannel instance with real Redis client.""" + return RedisBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls): + return f"test_topic_{uuid.uuid4()}" + + # ==================== Basic Functionality Tests ====================' + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel): + topic_name = self._get_test_topic_name() + topic = broadcast_channel.topic(topic_name) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume(): + msgs = [] + consuming_event.set() + for msg in subscription: + msgs.append(msg) + return msgs + + with ThreadPoolExecutor(max_workers=1) as executor: + producer_future = executor.submit(consume) + consuming_event.wait() + subscription.close() + msgs = producer_future.result(timeout=1) + assert msgs == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel): + """Test complete end-to-end messaging flow.""" + topic_name = "test-topic" + message = b"hello world" + + # Create producer and subscriber + topic = broadcast_channel.topic(topic_name) + producer = topic.as_producer() + subscription = topic.subscribe() + + # Publish and receive message + + def producer_thread(): + time.sleep(0.1) # Small delay to ensure subscriber is ready + producer.publish(message) + time.sleep(0.1) + subscription.close() + + def consumer_thread() -> list[bytes]: + received_messages = [] + for msg in subscription: + received_messages.append(msg) + return received_messages + + # Run producer and consumer + with ThreadPoolExecutor(max_workers=2) as executor: + producer_future = executor.submit(producer_thread) + consumer_future = executor.submit(consumer_thread) + + # Wait for completion + producer_future.result(timeout=5.0) + received_messages = consumer_future.result(timeout=5.0) + + assert len(received_messages) == 1 + assert received_messages[0] == message + + def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel): + """Test message broadcasting to multiple subscribers.""" + topic_name = "broadcast-topic" + message = b"broadcast message" + subscriber_count = 5 + + # Create producer and multiple subscribers + topic = broadcast_channel.topic(topic_name) + producer = topic.as_producer() + subscriptions = [topic.subscribe() for _ in range(subscriber_count)] + + def producer_thread(): + time.sleep(0.2) # Allow all subscribers to connect + producer.publish(message) + time.sleep(0.2) + for sub in subscriptions: + sub.close() + + def consumer_thread(subscription: Subscription) -> list[bytes]: + received_msgs = [] + while True: + try: + msg = subscription.receive(0.1) + except SubscriptionClosedError: + break + if msg is None: + continue + received_msgs.append(msg) + if len(received_msgs) >= 1: + break + return received_msgs + + # Run producer and consumers + with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor: + producer_future = executor.submit(producer_thread) + consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions] + + # Wait for completion + producer_future.result(timeout=10.0) + msgs_by_consumers = [] + for future in as_completed(consumer_futures, timeout=10.0): + msgs_by_consumers.append(future.result()) + + # Close all subscriptions + for subscription in subscriptions: + subscription.close() + + # Verify all subscribers received the message + for msgs in msgs_by_consumers: + assert len(msgs) == 1 + assert msgs[0] == message + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel): + """Test that different topics are isolated from each other.""" + topic1_name = "topic1" + topic2_name = "topic2" + message1 = b"message for topic1" + message2 = b"message for topic2" + + # Create producers and subscribers for different topics + topic1 = broadcast_channel.topic(topic1_name) + topic2 = broadcast_channel.topic(topic2_name) + + def producer_thread(): + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + def consumer_by_thread(topic: Topic) -> list[bytes]: + subscription = topic.subscribe() + received = [] + with subscription: + for msg in subscription: + received.append(msg) + if len(received) >= 1: + break + return received + + # Run all threads + with ThreadPoolExecutor(max_workers=3) as executor: + producer_future = executor.submit(producer_thread) + consumer1_future = executor.submit(consumer_by_thread, topic1) + consumer2_future = executor.submit(consumer_by_thread, topic2) + + # Wait for completion + producer_future.result(timeout=5.0) + received_by_topic1 = consumer1_future.result(timeout=5.0) + received_by_topic2 = consumer2_future.result(timeout=5.0) + + # Verify topic isolation + assert len(received_by_topic1) == 1 + assert len(received_by_topic2) == 1 + assert received_by_topic1[0] == message1 + assert received_by_topic2[0] == message2 + + # ==================== Performance Tests ==================== + + def test_concurrent_producers(self, broadcast_channel: BroadcastChannel): + """Test multiple producers publishing to the same topic.""" + topic_name = "concurrent-producers-topic" + producer_count = 5 + messages_per_producer = 5 + + topic = broadcast_channel.topic(topic_name) + subscription = topic.subscribe() + + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def producer_thread(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced = set() + for i in range(messages_per_producer): + message = f"producer_{producer_idx}_msg_{i}".encode() + produced.add(message) + producer.publish(message) + time.sleep(0.001) # Small delay to avoid overwhelming + return produced + + def consumer_thread() -> set[bytes]: + received_msgs: set[bytes] = set() + with subscription: + consumer_ready.set() + while True: + try: + msg = subscription.receive(timeout=0.1) + except SubscriptionClosedError: + break + if msg is None: + if len(received_msgs) >= expected_total: + break + else: + continue + + received_msgs.add(msg) + return received_msgs + + # Run producers and consumer + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consumer_thread) + consumer_ready.wait() + producer_futures = [executor.submit(producer_thread, i) for i in range(producer_count)] + + sent_msgs: set[bytes] = set() + # Wait for completion + for future in as_completed(producer_futures, timeout=30.0): + sent_msgs.update(future.result()) + + subscription.close() + consumer_received_msgs = consumer_future.result(timeout=30.0) + + # Verify message content + assert sent_msgs == consumer_received_msgs + + # ==================== Resource Management Tests ==================== + + def test_subscription_cleanup(self, broadcast_channel: BroadcastChannel, redis_client: redis.Redis): + """Test proper cleanup of subscription resources.""" + topic_name = "cleanup-test-topic" + + # Create multiple subscriptions + topic = broadcast_channel.topic(topic_name) + + def _consume(sub: Subscription): + for i in sub: + pass + + subscriptions = [] + for i in range(5): + subscription = topic.subscribe() + subscriptions.append(subscription) + + # Start all subscriptions + thread = threading.Thread(target=_consume, args=(subscription,)) + thread.start() + time.sleep(0.01) + + # Verify subscriptions are active + pubsub_info = redis_client.pubsub_numsub(topic_name) + # pubsub_numsub returns list of tuples, find our topic + topic_subscribers = 0 + for channel, count in pubsub_info: + # the channel name returned by redis is bytes. + if channel == topic_name.encode(): + topic_subscribers = count + break + assert topic_subscribers >= 5 + + # Close all subscriptions + for subscription in subscriptions: + subscription.close() + + # Wait a bit for cleanup + time.sleep(1) + + # Verify subscriptions are cleaned up + pubsub_info_after = redis_client.pubsub_numsub(topic_name) + topic_subscribers_after = 0 + for channel, count in pubsub_info_after: + if channel == topic_name.encode(): + topic_subscribers_after = count + break + assert topic_subscribers_after == 0 diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py new file mode 100644 index 0000000000..dffad4142c --- /dev/null +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -0,0 +1,514 @@ +""" +Comprehensive unit tests for Redis broadcast channel implementation. + +This test suite covers all aspects of the Redis broadcast channel including: +- Basic functionality and contract compliance +- Error handling and edge cases +- Thread safety and concurrency +- Resource management and cleanup +- Performance and reliability scenarios +""" + +import dataclasses +import threading +import time +from collections.abc import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError +from libs.broadcast_channel.redis.channel import ( + BroadcastChannel as RedisBroadcastChannel, +) +from libs.broadcast_channel.redis.channel import ( + Topic, + _RedisSubscription, +) + + +class TestBroadcastChannel: + """Test cases for the main BroadcastChannel class.""" + + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + """Create a mock Redis client for testing.""" + client = MagicMock() + client.pubsub.return_value = MagicMock() + return client + + @pytest.fixture + def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel: + """Create a BroadcastChannel instance with mock Redis client.""" + return RedisBroadcastChannel(mock_redis_client) + + def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock): + """Test that topic() method returns a Topic instance with correct parameters.""" + topic_name = "test-topic" + topic = broadcast_channel.topic(topic_name) + + assert isinstance(topic, Topic) + assert topic._client == mock_redis_client + assert topic._topic == topic_name + + def test_topic_isolation(self, broadcast_channel: RedisBroadcastChannel): + """Test that different topic names create isolated Topic instances.""" + topic1 = broadcast_channel.topic("topic1") + topic2 = broadcast_channel.topic("topic2") + + assert topic1 is not topic2 + assert topic1._topic == "topic1" + assert topic2._topic == "topic2" + + +class TestTopic: + """Test cases for the Topic class.""" + + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + """Create a mock Redis client for testing.""" + client = MagicMock() + client.pubsub.return_value = MagicMock() + return client + + @pytest.fixture + def topic(self, mock_redis_client: MagicMock) -> Topic: + """Create a Topic instance for testing.""" + return Topic(mock_redis_client, "test-topic") + + def test_as_producer_returns_self(self, topic: Topic): + """Test that as_producer() returns self as Producer interface.""" + producer = topic.as_producer() + assert producer is topic + # Producer is a Protocol, check duck typing instead + assert hasattr(producer, "publish") + + def test_as_subscriber_returns_self(self, topic: Topic): + """Test that as_subscriber() returns self as Subscriber interface.""" + subscriber = topic.as_subscriber() + assert subscriber is topic + # Subscriber is a Protocol, check duck typing instead + assert hasattr(subscriber, "subscribe") + + def test_publish_calls_redis_publish(self, topic: Topic, mock_redis_client: MagicMock): + """Test that publish() calls Redis PUBLISH with correct parameters.""" + payload = b"test message" + topic.publish(payload) + + mock_redis_client.publish.assert_called_once_with("test-topic", payload) + + +@dataclasses.dataclass(frozen=True) +class SubscriptionTestCase: + """Test case data for subscription tests.""" + + name: str + buffer_size: int + payload: bytes + expected_messages: list[bytes] + should_drop: bool = False + description: str = "" + + +class TestRedisSubscription: + """Test cases for the _RedisSubscription class.""" + + @pytest.fixture + def mock_pubsub(self) -> MagicMock: + """Create a mock PubSub instance for testing.""" + pubsub = MagicMock() + pubsub.subscribe = MagicMock() + pubsub.unsubscribe = MagicMock() + pubsub.close = MagicMock() + pubsub.get_message = MagicMock() + return pubsub + + @pytest.fixture + def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]: + """Create a _RedisSubscription instance for testing.""" + subscription = _RedisSubscription( + pubsub=mock_pubsub, + topic="test-topic", + ) + yield subscription + subscription.close() + + @pytest.fixture + def started_subscription(self, subscription: _RedisSubscription) -> _RedisSubscription: + """Create a subscription that has been started.""" + subscription._start_if_needed() + return subscription + + # ==================== Lifecycle Tests ==================== + + def test_subscription_initialization(self, mock_pubsub: MagicMock): + """Test that subscription is properly initialized.""" + subscription = _RedisSubscription( + pubsub=mock_pubsub, + topic="test-topic", + ) + + assert subscription._pubsub is mock_pubsub + assert subscription._topic == "test-topic" + assert not subscription._closed.is_set() + assert subscription._dropped_count == 0 + assert subscription._listener_thread is None + assert not subscription._started + + def test_start_if_needed_first_call(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that _start_if_needed() properly starts subscription on first call.""" + subscription._start_if_needed() + + mock_pubsub.subscribe.assert_called_once_with("test-topic") + assert subscription._started is True + assert subscription._listener_thread is not None + + def test_start_if_needed_subsequent_calls(self, started_subscription: _RedisSubscription): + """Test that _start_if_needed() doesn't start subscription on subsequent calls.""" + original_thread = started_subscription._listener_thread + started_subscription._start_if_needed() + + # Should not create new thread or generator + assert started_subscription._listener_thread is original_thread + + def test_start_if_needed_when_closed(self, subscription: _RedisSubscription): + """Test that _start_if_needed() raises error when subscription is closed.""" + subscription.close() + + with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"): + subscription._start_if_needed() + + def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription): + """Test that _start_if_needed() raises error when pubsub is None.""" + subscription._pubsub = None + + with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"): + subscription._start_if_needed() + + def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that subscription works as context manager.""" + with subscription as sub: + assert sub is subscription + assert subscription._started is True + mock_pubsub.subscribe.assert_called_once_with("test-topic") + + def test_close_idempotent(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that close() is idempotent and can be called multiple times.""" + subscription._start_if_needed() + + # Close multiple times + subscription.close() + subscription.close() + subscription.close() + + # Should only cleanup once + mock_pubsub.unsubscribe.assert_called_once_with("test-topic") + mock_pubsub.close.assert_called_once() + assert subscription._pubsub is None + assert subscription._closed.is_set() + + def test_close_cleanup(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that close() properly cleans up all resources.""" + subscription._start_if_needed() + thread = subscription._listener_thread + + subscription.close() + + # Verify cleanup + mock_pubsub.unsubscribe.assert_called_once_with("test-topic") + mock_pubsub.close.assert_called_once() + assert subscription._pubsub is None + assert subscription._listener_thread is None + + # Wait for thread to finish (with timeout) + if thread and thread.is_alive(): + thread.join(timeout=1.0) + assert not thread.is_alive() + + # ==================== Message Processing Tests ==================== + + def test_message_iterator_with_messages(self, started_subscription: _RedisSubscription): + """Test message iterator behavior with messages in queue.""" + test_messages = [b"msg1", b"msg2", b"msg3"] + + # Add messages to queue + for msg in test_messages: + started_subscription._queue.put_nowait(msg) + + # Iterate through messages + iterator = iter(started_subscription) + received_messages = [] + + for msg in iterator: + received_messages.append(msg) + if len(received_messages) >= len(test_messages): + break + + assert received_messages == test_messages + + def test_message_iterator_when_closed(self, subscription: _RedisSubscription): + """Test that iterator raises error when subscription is closed.""" + subscription.close() + + with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"): + iter(subscription) + + # ==================== Message Enqueue Tests ==================== + + def test_enqueue_message_success(self, started_subscription: _RedisSubscription): + """Test successful message enqueue.""" + payload = b"test message" + + started_subscription._enqueue_message(payload) + + assert started_subscription._queue.qsize() == 1 + assert started_subscription._queue.get_nowait() == payload + + def test_enqueue_message_when_closed(self, subscription: _RedisSubscription): + """Test message enqueue when subscription is closed.""" + subscription.close() + payload = b"test message" + + # Should not raise exception, but should not enqueue + subscription._enqueue_message(payload) + + assert subscription._queue.empty() + + def test_enqueue_message_with_full_queue(self, started_subscription: _RedisSubscription): + """Test message enqueue with full queue (dropping behavior).""" + # Fill the queue + for i in range(started_subscription._queue.maxsize): + started_subscription._queue.put_nowait(f"old_msg_{i}".encode()) + + # Try to enqueue new message (should drop oldest) + new_message = b"new_message" + started_subscription._enqueue_message(new_message) + + # Should have dropped one message and added new one + assert started_subscription._dropped_count == 1 + + # New message should be in queue + messages = [] + while not started_subscription._queue.empty(): + messages.append(started_subscription._queue.get_nowait()) + + assert new_message in messages + + # ==================== Listener Thread Tests ==================== + + @patch("time.sleep", side_effect=lambda x: None) # Speed up test + def test_listener_thread_normal_operation( + self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock + ): + """Test listener thread normal operation.""" + # Mock message from Redis + mock_message = {"type": "message", "channel": "test-topic", "data": b"test payload"} + mock_pubsub.get_message.return_value = mock_message + + # Start listener + subscription._start_if_needed() + + # Wait a bit for processing + time.sleep(0.1) + + # Verify message was processed + assert not subscription._queue.empty() + assert subscription._queue.get_nowait() == b"test payload" + + def test_listener_thread_ignores_subscribe_messages(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that listener thread ignores subscribe/unsubscribe messages.""" + mock_message = {"type": "subscribe", "channel": "test-topic", "data": 1} + mock_pubsub.get_message.return_value = mock_message + + subscription._start_if_needed() + time.sleep(0.1) + + # Should not enqueue subscribe messages + assert subscription._queue.empty() + + def test_listener_thread_ignores_wrong_channel(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that listener thread ignores messages from wrong channels.""" + mock_message = {"type": "message", "channel": "wrong-topic", "data": b"test payload"} + mock_pubsub.get_message.return_value = mock_message + + subscription._start_if_needed() + time.sleep(0.1) + + # Should not enqueue messages from wrong channels + assert subscription._queue.empty() + + def test_listener_thread_handles_redis_exceptions(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that listener thread handles Redis exceptions gracefully.""" + mock_pubsub.get_message.side_effect = Exception("Redis error") + + subscription._start_if_needed() + + # Wait for thread to handle exception + time.sleep(0.2) + + # Thread should still be alive but not processing + assert subscription._listener_thread is not None + assert not subscription._listener_thread.is_alive() + + def test_listener_thread_stops_when_closed(self, subscription: _RedisSubscription, mock_pubsub: MagicMock): + """Test that listener thread stops when subscription is closed.""" + subscription._start_if_needed() + thread = subscription._listener_thread + + # Close subscription + subscription.close() + + # Wait for thread to finish + if thread is not None and thread.is_alive(): + thread.join(timeout=1.0) + + assert thread is None or not thread.is_alive() + + # ==================== Table-driven Tests ==================== + + @pytest.mark.parametrize( + "test_case", + [ + SubscriptionTestCase( + name="basic_message", + buffer_size=5, + payload=b"hello world", + expected_messages=[b"hello world"], + description="Basic message publishing and receiving", + ), + SubscriptionTestCase( + name="empty_message", + buffer_size=5, + payload=b"", + expected_messages=[b""], + description="Empty message handling", + ), + SubscriptionTestCase( + name="large_message", + buffer_size=5, + payload=b"x" * 10000, + expected_messages=[b"x" * 10000], + description="Large message handling", + ), + SubscriptionTestCase( + name="unicode_message", + buffer_size=5, + payload="你好世界".encode(), + expected_messages=["你好世界".encode()], + description="Unicode message handling", + ), + ], + ) + def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock): + """Test various subscription scenarios using table-driven approach.""" + subscription = _RedisSubscription( + pubsub=mock_pubsub, + topic="test-topic", + ) + + # Simulate receiving message + mock_message = {"type": "message", "channel": "test-topic", "data": test_case.payload} + mock_pubsub.get_message.return_value = mock_message + + try: + with subscription: + # Wait for message processing + time.sleep(0.1) + + # Collect received messages + received = [] + for msg in subscription: + received.append(msg) + if len(received) >= len(test_case.expected_messages): + break + + assert received == test_case.expected_messages, f"Failed: {test_case.description}" + finally: + subscription.close() + + def test_concurrent_close_and_enqueue(self, started_subscription: _RedisSubscription): + """Test concurrent close and enqueue operations.""" + errors = [] + + def close_subscription(): + try: + time.sleep(0.05) # Small delay + started_subscription.close() + except Exception as e: + errors.append(e) + + def enqueue_messages(): + try: + for i in range(50): + started_subscription._enqueue_message(f"msg_{i}".encode()) + time.sleep(0.001) + except Exception as e: + errors.append(e) + + # Start threads + close_thread = threading.Thread(target=close_subscription) + enqueue_thread = threading.Thread(target=enqueue_messages) + + close_thread.start() + enqueue_thread.start() + + # Wait for completion + close_thread.join(timeout=2.0) + enqueue_thread.join(timeout=2.0) + + # Should not have any errors (operations should be safe) + assert len(errors) == 0 + + # ==================== Error Handling Tests ==================== + + def test_iterator_after_close(self, subscription: _RedisSubscription): + """Test iterator behavior after close.""" + subscription.close() + + with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"): + iter(subscription) + + def test_start_after_close(self, subscription: _RedisSubscription): + """Test start attempts after close.""" + subscription.close() + + with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"): + subscription._start_if_needed() + + def test_pubsub_none_operations(self, subscription: _RedisSubscription): + """Test operations when pubsub is None.""" + subscription._pubsub = None + + with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"): + subscription._start_if_needed() + + # Close should still work + subscription.close() # Should not raise + + def test_channel_name_variations(self, mock_pubsub: MagicMock): + """Test various channel name formats.""" + channel_names = [ + "simple", + "with-dashes", + "with_underscores", + "with.numbers", + "WITH.UPPERCASE", + "mixed-CASE_name", + "very.long.channel.name.with.multiple.parts", + ] + + for channel_name in channel_names: + subscription = _RedisSubscription( + pubsub=mock_pubsub, + topic=channel_name, + ) + + subscription._start_if_needed() + mock_pubsub.subscribe.assert_called_with(channel_name) + subscription.close() + + def test_received_on_closed_subscription(self, subscription: _RedisSubscription): + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive()