feat: support redis xstream (#32586)

This commit is contained in:
wangxiaolei 2026-03-04 13:18:55 +08:00 committed by GitHub
parent e14b09d4db
commit 2f4c740d46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 558 additions and 21 deletions

View File

@ -1,7 +1,7 @@
from typing import Literal, Protocol
from urllib.parse import quote_plus, urlunparse
from pydantic import Field
from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings
@ -23,41 +23,56 @@ class RedisConfigDefaultsMixin:
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
"""
Configuration settings for Redis pub/sub streaming.
Configuration settings for event transport between API and workers.
Supported transports:
- pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once)
- sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling)
- streams: Redis Streams (at-least-once, supports late subscribers)
"""
PUBSUB_REDIS_URL: str | None = Field(
alias="PUBSUB_REDIS_URL",
validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"),
description=(
"Redis connection URL for pub/sub streaming events between API "
"and celery worker, defaults to url constructed from "
"`REDIS_*` configurations"
"Redis connection URL for streaming events between API and celery worker; "
"defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL."
),
default=None,
)
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
description=(
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
"recommended to enable this for large deployments."
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
),
default=False,
)
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field(
validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"),
description=(
"Pub/sub channel type for streaming events. "
"Valid options are:\n"
"\n"
" - pubsub: for normal Pub/Sub\n"
" - sharded: for sharded Pub/Sub\n"
"\n"
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
"for large deployments."
"Event transport type. Options are:\n\n"
" - pubsub: normal Pub/Sub (at-most-once)\n"
" - sharded: sharded Pub/Sub (at-most-once)\n"
" - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n"
"Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n"
"Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n"
"the risk of data loss from Redis auto-eviction under memory pressure.\n"
"Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE."
),
default="pubsub",
)
PUBSUB_STREAMS_RETENTION_SECONDS: int = Field(
validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"),
description=(
"When using 'streams', expire each stream key this many seconds after the last event is published. "
"Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS."
),
default=600,
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:

View File

@ -18,6 +18,7 @@ from dify_app import DifyApp
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
@ -288,6 +289,11 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
return StreamsBroadcastChannel(
_pubsub_redis_client,
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
)
return RedisBroadcastChannel(_pubsub_redis_client)

View File

@ -0,0 +1,159 @@
from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
logger = logging.getLogger(__name__)
class StreamsBroadcastChannel:
    """Broadcast channel backed by Redis Streams.

    Characteristics:
    - At-least-once delivery for late subscribers within the retention window.
    - One dedicated Redis Stream key per topic.
    - Each stream key expires ``retention_seconds`` after the last publish,
      which bounds storage.
    """

    def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
        self._client = redis_client
        # Normalize to a non-negative integer; falsy/negative values disable expiry.
        seconds = int(retention_seconds or 0)
        self._retention_seconds = seconds if seconds > 0 else 0

    def topic(self, topic: str) -> StreamsTopic:
        """Return a stream-backed topic handle bound to this channel's client."""
        return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds)
class StreamsTopic:
    """A single broadcast topic stored as a dedicated Redis Stream key.

    Acts as both producer and subscriber factory. Published payloads are
    appended via XADD under the ``data`` field; the stream is capped at
    ``max_length`` entries and (optionally) expires ``retention_seconds``
    after the most recent publish.
    """

    def __init__(
        self,
        redis_client: Redis | RedisCluster,
        topic: str,
        *,
        retention_seconds: int = 600,
        max_length: int = 5000,
    ):
        self._client = redis_client
        self._topic = topic
        self._key = f"stream:{topic}"
        self._retention_seconds = retention_seconds
        # Upper bound on stream entries; XADD trims entries beyond this.
        # Parameterized (default preserves the previous hard-coded 5000).
        self.max_length = max_length

    def as_producer(self) -> Producer:
        return self

    def publish(self, payload: bytes) -> None:
        """Append *payload* to the stream and refresh the key's TTL."""
        self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length)
        if self._retention_seconds > 0:
            try:
                self._client.expire(self._key, self._retention_seconds)
            except Exception as e:
                # Best-effort TTL refresh: a failed EXPIRE must not break publishing.
                logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True)

    def as_subscriber(self) -> Subscriber:
        return self

    def subscribe(self) -> Subscription:
        return _StreamsSubscription(self._client, self._key)
class _StreamsSubscription(Subscription):
    """Subscription over a single Redis Stream key.

    A daemon listener thread polls XREAD and pushes raw payload bytes onto an
    internal queue; `receive()` and iteration drain that queue. Closing the
    subscription signals the listener to stop; the listener always enqueues a
    sentinel on exit so blocked consumers wake up and observe closure.
    """

    # Enqueued by the listener thread on exit to unblock waiting consumers.
    _SENTINEL = object()

    def __init__(self, client: Redis | RedisCluster, key: str):
        self._client = client
        self._key = key
        self._closed = threading.Event()
        # "0-0" replays the stream from its beginning (late-subscriber support).
        self._last_id = "0-0"
        self._queue: queue.Queue[object] = queue.Queue()
        self._start_lock = threading.Lock()
        self._listener: threading.Thread | None = None

    def _listen(self) -> None:
        """Listener thread body: XREAD in bounded blocks and enqueue payloads."""
        try:
            while not self._closed.is_set():
                # block=1000ms bounds each wait so closure is noticed promptly.
                streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
                if not streams:
                    continue
                for _key, entries in streams:
                    for entry_id, fields in entries:
                        data = None
                        if isinstance(fields, dict):
                            data = fields.get(b"data")
                        data_bytes: bytes | None = None
                        if isinstance(data, str):
                            data_bytes = data.encode()
                        elif isinstance(data, (bytes, bytearray)):
                            data_bytes = bytes(data)
                        if data_bytes is not None:
                            self._queue.put_nowait(data_bytes)
                        # Advance past this entry even if its payload was unusable.
                        self._last_id = entry_id
        except Exception:
            # Fix: previously an xread/connection error killed the thread
            # silently. Log it; the sentinel below still unblocks consumers.
            logger.exception("Streams subscription listener for key %s failed", self._key)
        finally:
            self._queue.put_nowait(self._SENTINEL)
            self._listener = None

    def _start_if_needed(self) -> None:
        """Lazily start the listener thread exactly once (never after close)."""
        if self._listener is not None:
            return
        # Ensure only one listener thread is created under concurrent calls
        with self._start_lock:
            if self._listener is not None or self._closed.is_set():
                return
            self._listener = threading.Thread(
                target=self._listen,
                name=f"redis-streams-sub-{self._key}",
                daemon=True,
            )
            self._listener.start()

    def __iter__(self) -> Iterator[bytes]:
        # Iterator delegates to receive with timeout; stops on closure.
        # Fix: receive() raises SubscriptionClosedError once the sentinel is
        # seen, so without the catch the consumer loop would raise instead of
        # stopping cleanly as documented.
        self._start_if_needed()
        while not self._closed.is_set():
            try:
                item = self.receive(timeout=1)
            except SubscriptionClosedError:
                return
            if item is not None:
                yield item

    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """Return the next payload, or None when *timeout* elapses.

        Raises:
            SubscriptionClosedError: if the subscription is (or becomes) closed.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis streams subscription is closed")
        self._start_if_needed()
        try:
            if timeout is None:
                item = self._queue.get()
            else:
                item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None
        if item is self._SENTINEL or self._closed.is_set():
            raise SubscriptionClosedError("The Redis streams subscription is closed")
        assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
        return bytes(item)

    def close(self) -> None:
        """Idempotently stop the listener and mark the subscription closed."""
        if self._closed.is_set():
            return
        self._closed.set()
        listener = self._listener
        if listener is not None:
            # Listener polls with a 1s block, so 2s is a generous stop bound.
            listener.join(timeout=2.0)
            if listener.is_alive():
                logger.warning(
                    "Streams subscription listener for key %s did not stop within timeout; keeping reference.",
                    self._key,
                )
            else:
                self._listener = None

    # Context manager helpers
    def __enter__(self) -> Self:
        self._start_if_needed()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
        self.close()
        return None

View File

@ -38,6 +38,13 @@ if TYPE_CHECKING:
class AppGenerateService:
@staticmethod
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
"""
Build a subscription callback that coordinates when the background task starts.
- streams transport: start immediately (events are durable; late subscribers can replay).
- pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task
still runs if the client never connects.
"""
started = False
lock = threading.Lock()
@ -54,10 +61,18 @@ class AppGenerateService:
started = True
return True
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
# The Celery task may publish the first event before the API side actually subscribes,
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
# but also use a short fallback timer so the task still runs if the client never consumes.
channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
if channel_type == "streams":
# With Redis Streams, we can safely start right away; consumers can read past events.
_try_start()
# Keep return type Callable[[], None] consistent while allowing an extra (no-op) call.
def _on_subscribe_streams() -> None:
_try_start()
return _on_subscribe_streams
# Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback.
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
timer.daemon = True
timer.start()

View File

@ -0,0 +1,145 @@
import time
import pytest
from libs.broadcast_channel.redis.streams_channel import (
StreamsBroadcastChannel,
StreamsTopic,
_StreamsSubscription,
)
class FakeStreamsRedis:
    """Minimal in-memory Redis Streams stub for unit tests.

    - Keeps entries per key as ``[(id, {b"data": bytes}), ...]``.
    - ``xadd`` appends and returns an auto-increment id such as ``"1-0"``.
    - ``xread`` yields only entries strictly newer than ``last_id``.
    - ``expire`` is counted but never actually expires anything.
    """

    def __init__(self) -> None:
        self._store: dict[str, list[tuple[str, dict]]] = {}
        self._next_id: dict[str, int] = {}
        self._expire_calls: dict[str, int] = {}

    # --- producer side --------------------------------------------------
    def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
        """Append an entry; ``maxlen`` is accepted for API parity but ignored."""
        seq = self._next_id.get(key, 0) + 1
        self._next_id[key] = seq
        new_id = f"{seq}-0"
        self._store.setdefault(key, []).append((new_id, fields))
        return new_id

    def expire(self, key: str, seconds: int) -> None:
        """Record that EXPIRE was invoked for *key* (no TTL semantics)."""
        self._expire_calls[key] = self._expire_calls.get(key, 0) + 1

    # --- consumer side --------------------------------------------------
    def xread(self, streams: dict, block: int | None = None, count: int | None = None):
        """Return entries strictly after ``last_id``; simulate a short block when empty."""
        assert len(streams) == 1
        key, last_id = next(iter(streams.items()))
        entries = self._store.get(key, [])
        begin = 0
        if last_id != "0-0":
            for idx, (entry_id, _fields) in enumerate(entries):
                if entry_id == last_id:
                    begin = idx + 1
                    break
        if begin >= len(entries):
            # Bounded pretend-block so listener loops don't spin hot.
            if block and block > 0:
                time.sleep(min(0.01, block / 1000.0))
            return []
        stop = len(entries) if count is None else min(len(entries), begin + count)
        return [(key, entries[begin:stop])]
@pytest.fixture
def fake_redis() -> FakeStreamsRedis:
    # Fresh in-memory streams stub per test (no shared state between tests).
    return FakeStreamsRedis()


@pytest.fixture
def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel:
    # Channel under test, wired to the fake client with a short retention window.
    return StreamsBroadcastChannel(fake_redis, retention_seconds=60)
class TestStreamsBroadcastChannel:
    def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis):
        """topic() returns a StreamsTopic bound to the channel's client and key prefix."""
        created = streams_channel.topic("alpha")
        assert isinstance(created, StreamsTopic)
        assert created._client is fake_redis
        assert created._topic == "alpha"
        assert created._key == "stream:alpha"

    def test_publish_calls_xadd_and_expire(
        self,
        streams_channel: StreamsBroadcastChannel,
        fake_redis: FakeStreamsRedis,
    ):
        """publish() stores the payload under the b"data" field and refreshes the TTL."""
        streams_channel.topic("beta").publish(b"hello")
        # One entry stored under the stream key (payload keyed by b"data").
        stored = fake_redis._store["stream:beta"]
        assert stored[0][1] == {b"data": b"hello"}
        # EXPIRE was invoked at least once after publishing.
        assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
class TestStreamsSubscription:
    def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
        """A late subscriber replays events that were published before subscribe()."""
        topic = streams_channel.topic("gamma")
        topic.publish(b"e1")
        topic.publish(b"e2")
        subscription = topic.subscribe()
        assert isinstance(subscription, _StreamsSubscription)
        collected: list[bytes] = []
        with subscription:
            # Let the listener thread perform at least one xread pass.
            time.sleep(0.05)
            # Drain via receive() with a bounded number of polls (avoids
            # indefinite iteration inside the test).
            for _ in range(5):
                message = subscription.receive(timeout=0.1)
                if message is None:
                    break
                collected.append(message)
        assert collected == [b"e1", b"e2"]

    def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
        """receive() returns None when nothing arrives within the timeout."""
        with streams_channel.topic("delta").subscribe() as subscription:
            # Nothing has been published yet.
            assert subscription.receive(timeout=0.05) is None

    def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel):
        """After close(), receive() raises SubscriptionClosedError."""
        from libs.broadcast_channel.exc import SubscriptionClosedError

        subscription = streams_channel.topic("epsilon").subscribe()
        with subscription:
            # Listener running; close must not crash (and is idempotent on exit).
            subscription.close()
        with pytest.raises(SubscriptionClosedError):
            subscription.receive()

    def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis):
        """retention_seconds=0 disables the EXPIRE call on publish."""
        zero_retention = StreamsBroadcastChannel(fake_redis, retention_seconds=0)
        zero_retention.topic("zeta").publish(b"payload")
        assert fake_redis._expire_calls.get("stream:zeta") is None

View File

@ -0,0 +1,197 @@
import json
import uuid
from collections import defaultdict, deque
import pytest
from core.app.apps.message_generator import MessageGenerator
from models.model import AppMode
from services.app_generate_service import AppGenerateService
# -----------------------------
# Fakes for Redis Pub/Sub flow
# -----------------------------
class _FakePubSub:
def __init__(self, store: dict[str, deque[bytes]]):
self._store = store
self._subs: set[str] = set()
self._closed = False
def subscribe(self, topic: str) -> None:
self._subs.add(topic)
def unsubscribe(self, topic: str) -> None:
self._subs.discard(topic)
def close(self) -> None:
self._closed = True
def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1):
# simulate a non-blocking poll; return first available
if self._closed:
return None
for t in list(self._subs):
q = self._store.get(t)
if q and len(q) > 0:
payload = q.popleft()
return {"type": "message", "channel": t, "data": payload}
# no message
return None
class _FakeRedisClient:
def __init__(self, store: dict[str, deque[bytes]]):
self._store = store
def pubsub(self):
return _FakePubSub(self._store)
def publish(self, topic: str, payload: bytes) -> None:
self._store.setdefault(topic, deque()).append(payload)
# ------------------------------------
# Fakes for Redis Streams (XADD/XREAD)
# ------------------------------------
class _FakeStreams:
def __init__(self) -> None:
# key -> list[(id, {field: value})]
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
self._seq: dict[str, int] = defaultdict(int)
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
self._seq[key] += 1
eid = f"{self._seq[key]}-0"
self._data[key].append((eid, fields))
return eid
def expire(self, key: str, seconds: int) -> None:
# no-op for tests
return None
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
assert len(streams) == 1
key, last_id = next(iter(streams.items()))
entries = self._data.get(key, [])
start = 0
if last_id != "0-0":
for i, (eid, _f) in enumerate(entries):
if eid == last_id:
start = i + 1
break
if start >= len(entries):
return []
end = len(entries) if count is None else min(len(entries), start + count)
return [(key, entries[start:end])]
@pytest.fixture
def _patch_get_channel_streams(monkeypatch):
    """Route the broadcast channel to an in-memory Streams fake and force 'streams' mode."""
    from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel

    fake = _FakeStreams()
    chan = StreamsBroadcastChannel(fake, retention_seconds=60)
    # Fix: removed a dead inner `_get_channel` function that was defined but
    # never used (the lambdas below are what actually get patched in).
    # Patch both the source and the imported alias used by MessageGenerator.
    monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan)
    monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan)
    # Ensure AppGenerateService sees streams mode.
    import services.app_generate_service as ags

    monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False)
@pytest.fixture
def _patch_get_channel_pubsub(monkeypatch):
    """Route the broadcast channel to an in-memory Pub/Sub fake and force 'pubsub' mode."""
    from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel

    store: dict[str, deque[bytes]] = defaultdict(deque)
    client = _FakeRedisClient(store)
    chan = RedisBroadcastChannel(client)
    # Fix: removed a dead inner `_get_channel` function that was defined but
    # never used (the lambdas below are what actually get patched in).
    # Patch both the source and the imported alias used by MessageGenerator.
    monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan)
    monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan)
    # Ensure AppGenerateService sees pubsub mode.
    import services.app_generate_service as ags

    monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False)
def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]):
    """Serialize and publish each event to the topic MessageGenerator reads from."""
    target_topic = MessageGenerator.get_response_topic(app_mode, run_id)
    for event in events:
        target_topic.publish(json.dumps(event).encode())
@pytest.mark.usefixtures("_patch_get_channel_streams")
def test_streams_full_flow_prepublish_and_replay():
    """With the streams transport, events published before retrieval are replayed in order."""
    app_mode = AppMode.WORKFLOW
    run_id = str(uuid.uuid4())
    events = [{"event": "workflow_started"}, {"event": "workflow_finished"}]

    def start_task():
        # Publishes both events immediately; streams are durable, so this is safe.
        _publish_events(app_mode, run_id, events)

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)
    # Begin retrieving BEFORE subscription is established; in streams mode the
    # task was also started immediately by the builder above.
    event_stream = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)
    seen = []
    for message in event_stream:
        if isinstance(message, str):
            # Skip ping keep-alive events.
            continue
        seen.append(message)
        if message.get("event") == "workflow_finished":
            break
    assert [m.get("event") for m in seen] == ["workflow_started", "workflow_finished"]
@pytest.mark.usefixtures("_patch_get_channel_pubsub")
def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch):
    """With pub/sub transport, the task starts only on subscribe and events arrive in order."""
    # Shorten the fallback timer in case it fires unexpectedly.
    monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50)
    app_mode = AppMode.WORKFLOW
    run_id = str(uuid.uuid4())
    published_order: list[str] = []

    def start_task():
        # Invoked on subscribe: publish both events in sequence.
        batch = [{"event": "workflow_started"}, {"event": "workflow_finished"}]
        _publish_events(app_mode, run_id, batch)
        published_order.extend(item["event"] for item in batch)

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)
    # Nothing published yet; the producer only runs once subscription happens.
    assert published_order == []
    event_stream = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)
    seen = []
    for message in event_stream:
        if isinstance(message, str):
            continue
        seen.append(message)
        if message.get("event") == "workflow_finished":
            break
    # The publish happened, and the consumer observed events in publish order.
    assert published_order == ["workflow_started", "workflow_finished"]
    assert [m.get("event") for m in seen] == ["workflow_started", "workflow_finished"]