diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index 912a48d26ae..01a9e668bcc 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -165,14 +165,20 @@ class RedisSubscriptionBase(Subscription): except queue.Empty: continue + if self._closed.is_set(): + return + yield item @override def __iter__(self) -> Iterator[bytes]: """Return an iterator over messages from the subscription.""" if self._closed.is_set(): - raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed") - self._start_if_needed() + return iter(()) + try: + self._start_if_needed() + except SubscriptionClosedError: + return iter(()) return iter(self._message_iterator()) @override @@ -209,10 +215,18 @@ class RedisSubscriptionBase(Subscription): @override def close(self) -> None: """Close the subscription and clean up resources.""" - if self._closed.is_set(): - return + with self._start_lock: + if self._closed.is_set(): + return + + self._closed.set() + listener = self._listener_thread + self._listener_thread = None + started = self._started + + if started: + self._unblock_message_iterator() - self._closed.set() # Send a control event on the same Redis channel to unblock the self._publish_close_event() @@ -220,10 +234,21 @@ class RedisSubscriptionBase(Subscription): # message retrieval method should NOT be called concurrently. # # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread. - listener = self._listener_thread - if listener is not None: + if listener is not None and listener.is_alive(): listener.join(timeout=2) - self._listener_thread = None + + def _unblock_message_iterator(self) -> None: + try: + self._queue.put_nowait(SIG_CLOSE) + except queue.Full: + try: + self._queue.get_nowait() + except queue.Empty: + pass + try: + self._queue.put_nowait(SIG_CLOSE) + except queue.Full: + pass # Abstract methods to be implemented by subclasses def _get_subscription_type(self) -> str: diff --git a/api/services/enterprise/rbac_service.py b/api/services/enterprise/rbac_service.py index dd1a50157f7..be94925b94a 100644 --- a/api/services/enterprise/rbac_service.py +++ b/api/services/enterprise/rbac_service.py @@ -770,6 +770,7 @@ class RBACService: data = _inner_call( "GET", f"{_INNER_PREFIX}/role-permissions/catalog", + params={"billing_enabled": dify_config.BILLING_ENABLED}, tenant_id=tenant_id, account_id=account_id, ) @@ -1585,7 +1586,7 @@ class RBACService: account_id=member_account_id, roles=[ RBACRole( - id="", + id=role, name=role, description="", is_builtin=True, diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index 7ab54555294..b74d494134b 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -17,7 +17,7 @@ from unittest.mock import MagicMock, patch import pytest -from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError +from libs.broadcast_channel.exc import SubscriptionClosedError from libs.broadcast_channel.redis.pubsub_channel import ( BroadcastChannel as RedisBroadcastChannel, ) @@ -395,11 +395,10 @@ class TestRedisSubscription: assert received_messages == test_messages def test_message_iterator_when_closed(self, subscription: _RedisSubscription): - """Test that iterator raises error when subscription is closed.""" + """Test that iterator stops when subscription is closed.""" subscription.close() - with pytest.raises(BroadcastChannelError, match="The Redis regular subscription is closed"): - iter(subscription) + assert list(subscription) == [] # ==================== Message Enqueue Tests ==================== @@ -616,8 +615,15 @@ class TestRedisSubscription: """Test iterator behavior after close.""" subscription.close() - with pytest.raises(SubscriptionClosedError, match="The Redis regular subscription is closed"): - iter(subscription) + assert list(subscription) == [] + + def test_close_does_not_join_unstarted_listener_thread(self, subscription: _RedisSubscription): + """close() should tolerate a listener object that has not been started yet.""" + subscription._listener_thread = threading.Thread(target=lambda: None) + + subscription.close() + + assert subscription._listener_thread is None def test_start_after_close(self, subscription: _RedisSubscription): """Test start attempts after close.""" @@ -818,11 +824,10 @@ class TestRedisShardedSubscription: assert received_messages == test_messages def test_message_iterator_when_closed(self, sharded_subscription: _RedisShardedSubscription): - """Test that iterator raises error when sharded subscription is closed.""" + """Test that iterator stops when sharded subscription is closed.""" sharded_subscription.close() - with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription is closed"): - iter(sharded_subscription) + assert list(sharded_subscription) == [] # ==================== Message Enqueue Tests ==================== @@ -1093,8 +1098,7 @@ class TestRedisShardedSubscription: """Test iterator behavior after close for sharded subscription.""" sharded_subscription.close() - with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription is closed"): - iter(sharded_subscription) + assert list(sharded_subscription) == [] def test_start_after_close(self, sharded_subscription: _RedisShardedSubscription): """Test start attempts after close for sharded subscription.""" @@ -1312,12 +1316,10 @@ class TestRedisSubscriptionCommon: assert received_messages == test_messages def test_message_iterator_when_closed(self, subscription, subscription_params): - """Test that iterator raises error when subscription is closed.""" - subscription_type, _ = subscription_params + """Test that iterator stops when subscription is closed.""" subscription.close() - with pytest.raises(SubscriptionClosedError, match=f"The Redis {subscription_type} subscription is closed"): - iter(subscription) + assert list(subscription) == [] # ==================== Message Enqueue Tests ==================== @@ -1390,11 +1392,9 @@ class TestRedisSubscriptionCommon: def test_iterator_after_close(self, subscription, subscription_params): """Test iterator behavior after close.""" - subscription_type, _ = subscription_params subscription.close() - with pytest.raises(SubscriptionClosedError, match=f"The Redis {subscription_type} subscription is closed"): - iter(subscription) + assert list(subscription) == [] def test_start_after_close(self, subscription, subscription_params): """Test start attempts after close.""" diff --git a/api/tests/unit_tests/services/enterprise/test_rbac_service.py b/api/tests/unit_tests/services/enterprise/test_rbac_service.py index 35f0d3ac674..b43c01778eb 100644 --- a/api/tests/unit_tests/services/enterprise/test_rbac_service.py +++ b/api/tests/unit_tests/services/enterprise/test_rbac_service.py @@ -46,7 +46,7 @@ class TestCatalog: assert call.tenant_id == "tenant-1" assert call.account_id == "acct-1" assert call.json is None - assert call.params is None + assert call.params == {"billing_enabled": svc.dify_config.BILLING_ENABLED} assert len(out.groups) == 1 assert out.groups[0].group_key == "workspace"