fix(api): fix concurrency issues in StreamsBroadcastChannel (#34061)

This commit is contained in:
QuantumGhost 2026-03-25 15:47:31 +08:00 committed by GitHub
parent b4af0d0f9a
commit 1789988be7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 75 additions and 40 deletions

View File

@ -63,24 +63,45 @@ class _StreamsSubscription(Subscription):
def __init__(self, client: Redis | RedisCluster, key: str):
self._client = client
self._key = key
self._closed = threading.Event()
# Setting initial last id to `$` to signal redis that we only want new messages.
#
# ref: https://redis.io/docs/latest/commands/xread/#the-special--id
self._last_id = "$"
self._queue: queue.Queue[object] = queue.Queue()
self._start_lock = threading.Lock()
# The `_lock` lock is used to
#
# 1. protect the _listener attribute
# 2. prevent repeated releases of underlying resoueces. (The _closed flag.)
#
# INVARIANT: the implementation must hold the lock while
# reading and writing the _listener / `_closed` attribute.
self._lock = threading.Lock()
self._closed: bool = False
# self._closed = threading.Event()
self._listener: threading.Thread | None = None
def _listen(self) -> None:
try:
while not self._closed.is_set():
streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
"""The `_listen` method handles the message retrieval loop. It requires a dedicated thread
and is not intended for direct invocation.
The thread is started by `_start_if_needed`.
"""
# since this method runs in a dedicated thread, acquiring `_lock` inside this method won't cause
# deadlock.
# Setting initial last id to `$` to signal redis that we only want new messages.
#
# ref: https://redis.io/docs/latest/commands/xread/#the-special--id
last_id = "$"
try:
while True:
with self._lock:
if self._closed:
break
streams = self._client.xread({self._key: last_id}, block=1000, count=100)
if not streams:
continue
for _key, entries in streams:
for _, entries in streams:
for entry_id, fields in entries:
data = None
if isinstance(fields, dict):
@ -92,37 +113,48 @@ class _StreamsSubscription(Subscription):
data_bytes = bytes(data)
if data_bytes is not None:
self._queue.put_nowait(data_bytes)
self._last_id = entry_id
last_id = entry_id
finally:
self._queue.put_nowait(self._SENTINEL)
self._listener = None
with self._lock:
self._listener = None
self._closed = True
def _start_if_needed(self) -> None:
"""This method must be called with `_lock` held."""
if self._listener is not None:
return
# Ensure only one listener thread is created under concurrent calls
with self._start_lock:
if self._listener is not None or self._closed.is_set():
return
self._listener = threading.Thread(
target=self._listen,
name=f"redis-streams-sub-{self._key}",
daemon=True,
)
self._listener.start()
if self._listener is not None or self._closed:
return
self._listener = threading.Thread(
target=self._listen,
name=f"redis-streams-sub-{self._key}",
daemon=True,
)
self._listener.start()
def __iter__(self) -> Iterator[bytes]:
# Iterator delegates to receive with timeout; stops on closure.
self._start_if_needed()
while not self._closed.is_set():
item = self.receive(timeout=1)
with self._lock:
self._start_if_needed()
while True:
with self._lock:
if self._closed:
return
try:
item = self.receive(timeout=1)
except SubscriptionClosedError:
return
if item is not None:
yield item
def receive(self, timeout: float | None = 0.1) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis streams subscription is closed")
self._start_if_needed()
with self._lock:
if self._closed:
raise SubscriptionClosedError("The Redis streams subscription is closed")
self._start_if_needed()
try:
if timeout is None:
@ -132,29 +164,33 @@ class _StreamsSubscription(Subscription):
except queue.Empty:
return None
if item is self._SENTINEL or self._closed.is_set():
if item is self._SENTINEL:
raise SubscriptionClosedError("The Redis streams subscription is closed")
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
return bytes(item)
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
listener = self._listener
if listener is not None:
with self._lock:
if self._closed:
return
self._closed = True
listener = self._listener
if listener is not None:
self._listener = None
# We close the listener outside of the with block to avoid holding the
# lock for a long time.
if listener is not None and listener.is_alive():
listener.join(timeout=2.0)
if listener.is_alive():
logger.warning(
"Streams subscription listener for key %s did not stop within timeout; keeping reference.",
self._key,
)
else:
self._listener = None
# Context manager helpers
def __enter__(self) -> Self:
self._start_if_needed()
with self._lock:
self._start_if_needed()
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool | None:

View File

@ -230,7 +230,7 @@ class TestStreamsSubscription:
if self._calls == 1:
key = next(iter(streams))
return [(key, [("1-0", self._fields)])]
subscription._closed.set()
subscription._closed = True
return []
subscription = _StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape")
@ -244,7 +244,6 @@ class TestStreamsSubscription:
received.append(bytes(item))
assert received == case.expected_messages
assert subscription._last_id == "1-0"
def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("iter")
@ -301,7 +300,7 @@ class TestStreamsSubscription:
def test_start_if_needed_returns_immediately_for_closed_subscription(self):
subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed")
subscription._closed.set()
subscription._closed = True
subscription._start_if_needed()
@ -316,7 +315,7 @@ class TestStreamsSubscription:
def fake_receive(timeout: float | None = 0.1) -> bytes | None:
value = next(items)
if value is not None:
subscription._closed.set()
subscription._closed = True
return value
subscription.receive = fake_receive # type: ignore[method-assign]