diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py index a72e1dd28f..8cddc5677a 100644 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -1,7 +1,7 @@ from typing import Literal, Protocol from urllib.parse import quote_plus, urlunparse -from pydantic import Field +from pydantic import AliasChoices, Field from pydantic_settings import BaseSettings @@ -23,41 +23,56 @@ class RedisConfigDefaultsMixin: class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): """ - Configuration settings for Redis pub/sub streaming. + Configuration settings for event transport between API and workers. + + Supported transports: + - pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once) + - sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling) + - streams: Redis Streams (at-least-once, supports late subscribers) """ PUBSUB_REDIS_URL: str | None = Field( - alias="PUBSUB_REDIS_URL", + validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"), description=( - "Redis connection URL for pub/sub streaming events between API " - "and celery worker, defaults to url constructed from " - "`REDIS_*` configurations" + "Redis connection URL for streaming events between API and celery worker; " + "defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL." ), default=None, ) PUBSUB_REDIS_USE_CLUSTERS: bool = Field( + validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"), description=( - "Enable Redis Cluster mode for pub/sub streaming. It's highly " - "recommended to enable this for large deployments." + "Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. " + "Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS." ), default=False, ) - PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field( + PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field( + validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"), description=( - "Pub/sub channel type for streaming events. " - "Valid options are:\n" - "\n" - " - pubsub: for normal Pub/Sub\n" - " - sharded: for sharded Pub/Sub\n" - "\n" - "It's highly recommended to use sharded Pub/Sub AND redis cluster " - "for large deployments." + "Event transport type. Options are:\n\n" + " - pubsub: normal Pub/Sub (at-most-once)\n" + " - sharded: sharded Pub/Sub (at-most-once)\n" + " - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n" + "Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n" + "Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n" + "the risk of data loss from Redis auto-eviction under memory pressure.\n" + "Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE." ), default="pubsub", ) + PUBSUB_STREAMS_RETENTION_SECONDS: int = Field( + validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"), + description=( + "When using 'streams', expire each stream key this many seconds after the last event is published. " + "Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS." + ), + default=600, + ) + def _build_default_pubsub_url(self) -> str: defaults = self._redis_defaults() if not defaults.REDIS_HOST or not defaults.REDIS_PORT: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index cadd9cb263..26262484f9 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -18,6 +18,7 @@ from dify_app import DifyApp from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock @@ -288,6 +289,11 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": return ShardedRedisBroadcastChannel(_pubsub_redis_client) + if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams": + return StreamsBroadcastChannel( + _pubsub_redis_client, + retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS, + ) return RedisBroadcastChannel(_pubsub_redis_client) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py new file mode 100644 index 0000000000..d6ec5504ca --- /dev/null +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import logging +import queue +import threading +from collections.abc import Iterator +from typing import Self + +from libs.broadcast_channel.channel import Producer, Subscriber, Subscription +from libs.broadcast_channel.exc import SubscriptionClosedError +from redis import Redis, RedisCluster + +logger = logging.getLogger(__name__) + + +class StreamsBroadcastChannel: + """ + Redis Streams based broadcast channel implementation. + + Characteristics: + - At-least-once delivery for late subscribers within the stream retention window. + - Each topic is stored as a dedicated Redis Stream key. + - The stream key expires `retention_seconds` after the last event is published (to bound storage). + """ + + def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600): + self._client = redis_client + self._retention_seconds = max(int(retention_seconds or 0), 0) + + def topic(self, topic: str) -> StreamsTopic: + return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds) + + +class StreamsTopic: + def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): + self._client = redis_client + self._topic = topic + self._key = f"stream:{topic}" + self._retention_seconds = retention_seconds + self.max_length = 5000 + + def as_producer(self) -> Producer: + return self + + def publish(self, payload: bytes) -> None: + self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length) + if self._retention_seconds > 0: + try: + self._client.expire(self._key, self._retention_seconds) + except Exception as e: + logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True) + + def as_subscriber(self) -> Subscriber: + return self + + def subscribe(self) -> Subscription: + return _StreamsSubscription(self._client, self._key) + + +class _StreamsSubscription(Subscription): + _SENTINEL = object() + + def __init__(self, client: Redis | RedisCluster, key: str): + self._client = client + self._key = key + self._closed = threading.Event() + self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() + self._start_lock = threading.Lock() + self._listener: threading.Thread | None = None + + def _listen(self) -> None: + try: + while not self._closed.is_set(): + streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + + if not streams: + continue + + for _key, entries in streams: + for entry_id, fields in entries: + data = None + if isinstance(fields, dict): + data = fields.get(b"data") + data_bytes: bytes | None = None + if isinstance(data, str): + data_bytes = data.encode() + elif isinstance(data, (bytes, bytearray)): + data_bytes = bytes(data) + if data_bytes is not None: + self._queue.put_nowait(data_bytes) + self._last_id = entry_id + finally: + self._queue.put_nowait(self._SENTINEL) + self._listener = None + + def _start_if_needed(self) -> None: + if self._listener is not None: + return + # Ensure only one listener thread is created under concurrent calls + with self._start_lock: + if self._listener is not None or self._closed.is_set(): + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() + + def __iter__(self) -> Iterator[bytes]: + # Iterator delegates to receive with timeout; stops on closure. + self._start_if_needed() + while not self._closed.is_set(): + item = self.receive(timeout=1) + if item is not None: + yield item + + def receive(self, timeout: float | None = 0.1) -> bytes | None: + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() + + try: + if timeout is None: + item = self._queue.get() + else: + item = self._queue.get(timeout=timeout) + except queue.Empty: + return None + + if item is self._SENTINEL or self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" + return bytes(item) + + def close(self) -> None: + if self._closed.is_set(): + return + self._closed.set() + listener = self._listener + if listener is not None: + listener.join(timeout=2.0) + if listener.is_alive(): + logger.warning( + "Streams subscription listener for key %s did not stop within timeout; keeping reference.", + self._key, + ) + else: + self._listener = None + + # Context manager helpers + def __enter__(self) -> Self: + self._start_if_needed() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool | None: + self.close() + return None diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 31003cb8f7..40013f2b66 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -38,6 +38,13 @@ if TYPE_CHECKING: class AppGenerateService: @staticmethod def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: + """ + Build a subscription callback that coordinates when the background task starts. + + - streams transport: start immediately (events are durable; late subscribers can replay). + - pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task + still runs if the client never connects. + """ started = False lock = threading.Lock() @@ -54,10 +61,18 @@ class AppGenerateService: started = True return True - # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. - # The Celery task may publish the first event before the API side actually subscribes, - # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, - # but also use a short fallback timer so the task still runs if the client never consumes. + channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE + if channel_type == "streams": + # With Redis Streams, we can safely start right away; consumers can read past events. + _try_start() + + # Keep return type Callable[[], None] consistent while allowing an extra (no-op) call. + def _on_subscribe_streams() -> None: + _try_start() + + return _on_subscribe_streams + + # Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback. timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) timer.daemon = True timer.start() diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py new file mode 100644 index 0000000000..248aa0b145 --- /dev/null +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -0,0 +1,145 @@ +import time + +import pytest + +from libs.broadcast_channel.redis.streams_channel import ( + StreamsBroadcastChannel, + StreamsTopic, + _StreamsSubscription, +) + + +class FakeStreamsRedis: + """Minimal in-memory Redis Streams stub for unit tests. + + - Stores entries per key as [(id, {b"data": bytes}), ...] + - xadd appends entries and returns an auto-increment id like "1-0" + - xread returns entries strictly greater than last_id + - expire is recorded but has no effect on behavior + """ + + def __init__(self) -> None: + self._store: dict[str, list[tuple[str, dict]]] = {} + self._next_id: dict[str, int] = {} + self._expire_calls: dict[str, int] = {} + + # Publisher API + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + """Append entry to stream; accept optional maxlen for API compatibility. + + The test double ignores maxlen trimming semantics; only records the entry. + """ + n = self._next_id.get(key, 0) + 1 + self._next_id[key] = n + entry_id = f"{n}-0" + self._store.setdefault(key, []).append((entry_id, fields)) + return entry_id + + def expire(self, key: str, seconds: int) -> None: + self._expire_calls[key] = self._expire_calls.get(key, 0) + 1 + + # Consumer API + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + # Expect a single key + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._store.get(key, []) + + # Find position strictly greater than last_id + start_idx = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start_idx = i + 1 + break + if start_idx >= len(entries): + # Simulate blocking wait (bounded) if requested + if block and block > 0: + time.sleep(min(0.01, block / 1000.0)) + return [] + + end_idx = len(entries) if count is None else min(len(entries), start_idx + count) + batch = entries[start_idx:end_idx] + return [(key, batch)] + + +@pytest.fixture +def fake_redis() -> FakeStreamsRedis: + return FakeStreamsRedis() + + +@pytest.fixture +def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel: + return StreamsBroadcastChannel(fake_redis, retention_seconds=60) + + +class TestStreamsBroadcastChannel: + def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis): + topic = streams_channel.topic("alpha") + assert isinstance(topic, StreamsTopic) + assert topic._client is fake_redis + assert topic._topic == "alpha" + assert topic._key == "stream:alpha" + + def test_publish_calls_xadd_and_expire( + self, + streams_channel: StreamsBroadcastChannel, + fake_redis: FakeStreamsRedis, + ): + topic = streams_channel.topic("beta") + payload = b"hello" + topic.publish(payload) + # One entry stored under stream key (bytes key for payload field) + assert fake_redis._store["stream:beta"][0][1] == {b"data": payload} + # Expire called after publish + assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + + +class TestStreamsSubscription: + def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("gamma") + # Pre-publish events before subscribing (late subscriber) + topic.publish(b"e1") + topic.publish(b"e2") + + sub = topic.subscribe() + assert isinstance(sub, _StreamsSubscription) + + received: list[bytes] = [] + with sub: + # Give listener thread a moment to xread + time.sleep(0.05) + # Drain using receive() to avoid indefinite iteration in tests + for _ in range(5): + msg = sub.receive(timeout=0.1) + if msg is None: + break + received.append(msg) + + assert received == [b"e1", b"e2"] + + def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("delta") + sub = topic.subscribe() + with sub: + # No messages yet + assert sub.receive(timeout=0.05) is None + + def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("epsilon") + sub = topic.subscribe() + with sub: + # Listener running; now close and ensure no crash + sub.close() + # After close, receive should raise SubscriptionClosedError + from libs.broadcast_channel.exc import SubscriptionClosedError + + with pytest.raises(SubscriptionClosedError): + sub.receive() + + def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis): + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=0) + topic = channel.topic("zeta") + topic.publish(b"payload") + # No expire recorded when retention is disabled + assert fake_redis._expire_calls.get("stream:zeta") is None diff --git a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py new file mode 100644 index 0000000000..e66d52f66b --- /dev/null +++ b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py @@ -0,0 +1,197 @@ +import json +import uuid +from collections import defaultdict, deque + +import pytest + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +# ----------------------------- +# Fakes for Redis Pub/Sub flow +# ----------------------------- +class _FakePubSub: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + self._subs: set[str] = set() + self._closed = False + + def subscribe(self, topic: str) -> None: + self._subs.add(topic) + + def unsubscribe(self, topic: str) -> None: + self._subs.discard(topic) + + def close(self) -> None: + self._closed = True + + def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1): + # simulate a non-blocking poll; return first available + if self._closed: + return None + for t in list(self._subs): + q = self._store.get(t) + if q and len(q) > 0: + payload = q.popleft() + return {"type": "message", "channel": t, "data": payload} + # no message + return None + + +class _FakeRedisClient: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + + def pubsub(self): + return _FakePubSub(self._store) + + def publish(self, topic: str, payload: bytes) -> None: + self._store.setdefault(topic, deque()).append(payload) + + +# ------------------------------------ +# Fakes for Redis Streams (XADD/XREAD) +# ------------------------------------ +class _FakeStreams: + def __init__(self) -> None: + # key -> list[(id, {field: value})] + self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list) + self._seq: dict[str, int] = defaultdict(int) + + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + # maxlen is accepted for API compatibility with redis-py; ignored in this test double + self._seq[key] += 1 + eid = f"{self._seq[key]}-0" + self._data[key].append((eid, fields)) + return eid + + def expire(self, key: str, seconds: int) -> None: + # no-op for tests + return None + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._data.get(key, []) + start = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start = i + 1 + break + if start >= len(entries): + return [] + end = len(entries) if count is None else min(len(entries), start + count) + return [(key, entries[start:end])] + + +@pytest.fixture +def _patch_get_channel_streams(monkeypatch): + from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + fake = _FakeStreams() + chan = StreamsBroadcastChannel(fake, retention_seconds=60) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees streams mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False) + + +@pytest.fixture +def _patch_get_channel_pubsub(monkeypatch): + from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel + + store: dict[str, deque[bytes]] = defaultdict(deque) + client = _FakeRedisClient(store) + chan = RedisBroadcastChannel(client) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees pubsub mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False) + + +def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]): + # Publish events to the same topic used by MessageGenerator + topic = MessageGenerator.get_response_topic(app_mode, run_id) + for ev in events: + topic.publish(json.dumps(ev).encode()) + + +@pytest.mark.usefixtures("_patch_get_channel_streams") +def test_streams_full_flow_prepublish_and_replay(): + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + # Build start_task that publishes two events immediately + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + + def start_task(): + _publish_events(app_mode, run_id, events) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Start retrieving BEFORE subscription is established; in streams mode, we also started immediately + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + # skip ping events + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] + + +@pytest.mark.usefixtures("_patch_get_channel_pubsub") +def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch): + # Speed up any potential timer if it accidentally triggers + monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50) + + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + published_order: list[str] = [] + + def start_task(): + # When called (on subscribe), publish both events + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + _publish_events(app_mode, run_id, events) + published_order.extend([e["event"] for e in events]) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Producer not started yet; only when subscribe happens + assert published_order == [] + + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + # Verify publish happened and consumer received in order + assert published_order == ["workflow_started", "workflow_finished"] + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"]