diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index aaeaf76f7b..983f785027 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -63,24 +63,45 @@ class _StreamsSubscription(Subscription): def __init__(self, client: Redis | RedisCluster, key: str): self._client = client self._key = key - self._closed = threading.Event() - # Setting initial last id to `$` to signal redis that we only want new messages. - # - # ref: https://redis.io/docs/latest/commands/xread/#the-special--id - self._last_id = "$" + self._queue: queue.Queue[object] = queue.Queue() - self._start_lock = threading.Lock() + + # The `_lock` lock is used to + # + # 1. protect the _listener attribute + # 2. prevent repeated releases of underlying resources. (The _closed flag.) + # + # INVARIANT: the implementation must hold the lock while + # reading and writing the _listener / `_closed` attribute. + self._lock = threading.Lock() + self._closed: bool = False + # self._closed = threading.Event() self._listener: threading.Thread | None = None def _listen(self) -> None: - try: - while not self._closed.is_set(): - streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + """The `_listen` method handles the message retrieval loop. It requires a dedicated thread + and is not intended for direct invocation. + The thread is started by `_start_if_needed`. + """ + + # since this method runs in a dedicated thread, acquiring `_lock` inside this method won't cause + # deadlock. + + # Setting initial last id to `$` to signal redis that we only want new messages. 
+ # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + last_id = "$" + try: + while True: + with self._lock: + if self._closed: + break + streams = self._client.xread({self._key: last_id}, block=1000, count=100) if not streams: continue - for _key, entries in streams: + for _, entries in streams: for entry_id, fields in entries: data = None if isinstance(fields, dict): @@ -92,37 +113,48 @@ class _StreamsSubscription(Subscription): data_bytes = bytes(data) if data_bytes is not None: self._queue.put_nowait(data_bytes) - self._last_id = entry_id + last_id = entry_id finally: self._queue.put_nowait(self._SENTINEL) - self._listener = None + with self._lock: + self._listener = None + self._closed = True def _start_if_needed(self) -> None: + """This method must be called with `_lock` held.""" if self._listener is not None: return # Ensure only one listener thread is created under concurrent calls - with self._start_lock: - if self._listener is not None or self._closed.is_set(): - return - self._listener = threading.Thread( - target=self._listen, - name=f"redis-streams-sub-{self._key}", - daemon=True, - ) - self._listener.start() + if self._listener is not None or self._closed: + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() def __iter__(self) -> Iterator[bytes]: # Iterator delegates to receive with timeout; stops on closure. 
- self._start_if_needed() - while not self._closed.is_set(): - item = self.receive(timeout=1) + with self._lock: + self._start_if_needed() + + while True: + with self._lock: + if self._closed: + return + try: + item = self.receive(timeout=1) + except SubscriptionClosedError: + return if item is not None: yield item def receive(self, timeout: float | None = 0.1) -> bytes | None: - if self._closed.is_set(): - raise SubscriptionClosedError("The Redis streams subscription is closed") - self._start_if_needed() + with self._lock: + if self._closed: + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() try: if timeout is None: @@ -132,29 +164,33 @@ class _StreamsSubscription(Subscription): except queue.Empty: return None - if item is self._SENTINEL or self._closed.is_set(): + if item is self._SENTINEL: raise SubscriptionClosedError("The Redis streams subscription is closed") assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" return bytes(item) def close(self) -> None: - if self._closed.is_set(): - return - self._closed.set() - listener = self._listener - if listener is not None: + with self._lock: + if self._closed: + return + self._closed = True + listener = self._listener + if listener is not None: + self._listener = None + # We close the listener outside of the with block to avoid holding the + # lock for a long time. 
+ if listener is not None and listener.is_alive(): listener.join(timeout=2.0) if listener.is_alive(): logger.warning( "Streams subscription listener for key %s did not stop within timeout; keeping reference.", self._key, ) - else: - self._listener = None # Context manager helpers def __enter__(self) -> Self: - self._start_if_needed() + with self._lock: + self._start_if_needed() return self def __exit__(self, exc_type, exc_value, traceback) -> bool | None: diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index bf548f69cf..0886b70ee5 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -230,7 +230,7 @@ class TestStreamsSubscription: if self._calls == 1: key = next(iter(streams)) return [(key, [("1-0", self._fields)])] - subscription._closed.set() + subscription._closed = True return [] subscription = _StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape") @@ -244,7 +244,6 @@ class TestStreamsSubscription: received.append(bytes(item)) assert received == case.expected_messages - assert subscription._last_id == "1-0" def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("iter") @@ -301,7 +300,7 @@ class TestStreamsSubscription: def test_start_if_needed_returns_immediately_for_closed_subscription(self): subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed") - subscription._closed.set() + subscription._closed = True subscription._start_if_needed() @@ -316,7 +315,7 @@ class TestStreamsSubscription: def fake_receive(timeout: float | None = 0.1) -> bytes | None: value = next(items) if value is not None: - subscription._closed.set() + subscription._closed = True return value 
subscription.receive = fake_receive # type: ignore[method-assign]