feat: support configurable redis key prefix (#35139)

This commit is contained in:
Blackoutta 2026-04-14 17:31:41 +08:00 committed by GitHub
parent bd7a9b5fcf
commit 736880e046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 522 additions and 74 deletions

View File

@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false

View File

@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
default=0,
)
REDIS_KEY_PREFIX: str = Field(
description="Optional global prefix for Redis keys, topics, and transport artifacts",
default="",
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,

View File

@ -9,6 +9,7 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import normalize_redis_key_prefix
class _CelerySentinelKwargsDict(TypedDict):
@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
password: str | None
class CelerySentinelTransportDict(TypedDict):
class CelerySentinelTransportDict(TypedDict, total=False):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict
global_keyprefix: str
class CelerySSLOptionsDict(TypedDict):
@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
transport_options: CelerySentinelTransportDict | dict[str, Any]
if dify_config.CELERY_USE_SENTINEL:
return CelerySentinelTransportDict(
transport_options = CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}
else:
transport_options = {}
global_keyprefix = get_celery_redis_global_keyprefix()
if global_keyprefix:
transport_options["global_keyprefix"] = global_keyprefix
return transport_options
def get_celery_redis_global_keyprefix() -> str | None:
    """Compute the ``global_keyprefix`` for Celery's Redis transport.

    Returns ``"<prefix>:"`` when a non-empty ``REDIS_KEY_PREFIX`` is
    configured, or ``None`` when namespace isolation is disabled (empty or
    whitespace-only prefix).
    """
    prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
    return f"{prefix}:" if prefix else None
def init_app(app: DifyApp) -> Celery:

View File

@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
from typing import Any, Union, cast
import redis
from redis import RedisError
@ -18,17 +18,26 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import (
normalize_redis_key_prefix,
serialize_redis_name,
serialize_redis_name_arg,
serialize_redis_name_args,
)
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
logger = logging.getLogger(__name__)
_normalize_redis_key_prefix = normalize_redis_key_prefix
_serialize_redis_name = serialize_redis_name
_serialize_redis_name_arg = serialize_redis_name_arg
_serialize_redis_name_args = serialize_redis_name_args
class RedisClientWrapper:
"""
A wrapper class for the Redis client that addresses the issue where the global
@ -59,68 +68,148 @@ class RedisClientWrapper:
if self._client is None:
self._client = client
if TYPE_CHECKING:
# Type hints for IDE support and static analysis
# These are not executed at runtime but provide type information
def get(self, name: str | bytes) -> Any: ...
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any: ...
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
def setnx(self, name: str | bytes, value: Any) -> Any: ...
def delete(self, *names: str | bytes) -> Any: ...
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Lock: ...
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
def __getattr__(self, item: str) -> Any:
def _require_client(self) -> redis.Redis | RedisCluster:
if self._client is None:
raise RuntimeError("Redis client is not initialized. Call init_app first.")
return getattr(self._client, item)
return self._client
def _get_prefix(self) -> str:
return dify_config.REDIS_KEY_PREFIX
def get(self, name: str | bytes) -> Any:
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any:
return self._require_client().set(
_serialize_redis_name_arg(name, self._get_prefix()),
value,
ex=ex,
px=px,
nx=nx,
xx=xx,
keepttl=keepttl,
get=get,
exat=exat,
pxat=pxat,
)
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
def setnx(self, name: str | bytes, value: Any) -> Any:
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
def delete(self, *names: str | bytes) -> Any:
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
def incr(self, name: str | bytes, amount: int = 1) -> Any:
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any:
return self._require_client().expire(
_serialize_redis_name_arg(name, self._get_prefix()),
time,
nx=nx,
xx=xx,
gt=gt,
lt=lt,
)
def exists(self, *names: str | bytes) -> Any:
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
def ttl(self, name: str | bytes) -> Any:
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
def getdel(self, name: str | bytes) -> Any:
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Any:
return self._require_client().lock(
_serialize_redis_name(name, self._get_prefix()),
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
def hgetall(self, name: str | bytes) -> Any:
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
def hlen(self, name: str | bytes) -> Any:
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any:
return self._require_client().zadd(
_serialize_redis_name_arg(name, self._get_prefix()),
cast(Any, mapping),
nx=nx,
xx=xx,
ch=ch,
incr=incr,
gt=gt,
lt=lt,
)
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
def zcard(self, name: str | bytes) -> Any:
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
def pubsub(self) -> PubSub:
return self._require_client().pubsub()
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
def __getattr__(self, item: str) -> Any:
return getattr(self._require_client(), item)
redis_client: RedisClientWrapper = RedisClientWrapper()

View File

@ -0,0 +1,32 @@
from configs import dify_config
def normalize_redis_key_prefix(prefix: str | None) -> str:
"""Normalize the configured Redis key prefix for consistent runtime use."""
if prefix is None:
return ""
return prefix.strip()
def get_redis_key_prefix() -> str:
    """Fetch ``REDIS_KEY_PREFIX`` from the live config and normalize it."""
    configured = dify_config.REDIS_KEY_PREFIX
    return normalize_redis_key_prefix(configured)
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
    """Map a logical Redis name onto its physical (possibly prefixed) form.

    When *prefix* is ``None`` the globally configured prefix is used; an
    empty (or whitespace-only) effective prefix leaves *name* unchanged,
    otherwise the result is ``"<prefix>:<name>"``.
    """
    if prefix is None:
        effective = get_redis_key_prefix()
    else:
        effective = normalize_redis_key_prefix(prefix)
    return f"{effective}:{name}" if effective else name
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
"""Prefix string Redis names while preserving bytes inputs unchanged."""
if isinstance(name, bytes):
return name
return serialize_redis_name(name, prefix)
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
    """Apply :func:`serialize_redis_name_arg` to each name, preserving order (bytes pass through unprefixed)."""
    return tuple(serialize_redis_name_arg(name, prefix) for name in names)

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@ -32,12 +33,13 @@ class Topic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
self._client.publish(self._redis_topic, payload)
def as_subscriber(self) -> Subscriber:
return self
@ -46,7 +48,7 @@ class Topic:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@ -30,12 +31,13 @@ class ShardedTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
def as_subscriber(self) -> Subscriber:
return self
@ -44,7 +46,7 @@ class ShardedTopic:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@ -6,6 +6,7 @@ import threading
from collections.abc import Iterator
from typing import Self
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
@ -35,7 +36,7 @@ class StreamsTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
self._client = redis_client
self._topic = topic
self._key = f"stream:{topic}"
self._key = serialize_redis_name(f"stream:{topic}")
self._retention_seconds = retention_seconds
self.max_length = 5000

View File

@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
lock = self._lock
if lock is None:
raise RuntimeError("Redis lock initialization failed.")
acquired = bool(lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()

View File

@ -33,6 +33,7 @@ REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
REDIS_KEY_PREFIX=
# PostgreSQL database configuration
DB_USERNAME=postgres

View File

@ -236,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
_ = DifyConfig().normalized_pubsub_redis_url
def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
    """REDIS_KEY_PREFIX defaults to "" when its env var is absent."""
    # Clear ambient env vars so only the explicitly set values reach DifyConfig.
    os.environ.clear()
    monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
    monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
    monkeypatch.setenv("DB_TYPE", "postgresql")
    monkeypatch.setenv("DB_USERNAME", "postgres")
    monkeypatch.setenv("DB_PASSWORD", "postgres")
    monkeypatch.setenv("DB_HOST", "localhost")
    monkeypatch.setenv("DB_PORT", "5432")
    monkeypatch.setenv("DB_DATABASE", "dify")
    # _env_file=None keeps a developer's local .env from leaking into the test.
    config = DifyConfig(_env_file=None)
    assert config.REDIS_KEY_PREFIX == ""
def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
    """REDIS_KEY_PREFIX is read verbatim from the environment."""
    # Clear ambient env vars so only the explicitly set values reach DifyConfig.
    os.environ.clear()
    monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
    monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
    monkeypatch.setenv("DB_TYPE", "postgresql")
    monkeypatch.setenv("DB_USERNAME", "postgres")
    monkeypatch.setenv("DB_PASSWORD", "postgres")
    monkeypatch.setenv("DB_HOST", "localhost")
    monkeypatch.setenv("DB_PORT", "5432")
    monkeypatch.setenv("DB_DATABASE", "dify")
    monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
    config = DifyConfig(_env_file=None)
    assert config.REDIS_KEY_PREFIX == "enterprise-a"
@pytest.mark.parametrize(
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
[

View File

@ -7,6 +7,47 @@ from unittest.mock import MagicMock, patch
class TestCelerySSLConfiguration:
"""Test suite for Celery SSL configuration."""
def test_get_celery_broker_transport_options_includes_global_keyprefix_for_redis(self):
    """A non-empty REDIS_KEY_PREFIX surfaces as ``global_keyprefix`` with a trailing colon."""
    mock_config = MagicMock()
    mock_config.CELERY_USE_SENTINEL = False
    mock_config.REDIS_KEY_PREFIX = "enterprise-a"
    with patch("extensions.ext_celery.dify_config", mock_config):
        # Import inside the patch so the function reads the mocked config.
        from extensions.ext_celery import get_celery_broker_transport_options

        result = get_celery_broker_transport_options()
        assert result["global_keyprefix"] == "enterprise-a:"
def test_get_celery_broker_transport_options_omits_global_keyprefix_when_prefix_empty(self):
    """A whitespace-only prefix normalizes to empty, so no ``global_keyprefix`` key is added."""
    mock_config = MagicMock()
    mock_config.CELERY_USE_SENTINEL = False
    mock_config.REDIS_KEY_PREFIX = " "
    with patch("extensions.ext_celery.dify_config", mock_config):
        from extensions.ext_celery import get_celery_broker_transport_options

        result = get_celery_broker_transport_options()
        assert "global_keyprefix" not in result
def test_get_celery_broker_transport_options_keeps_sentinel_and_adds_global_keyprefix(self):
    """Sentinel transport options survive intact when ``global_keyprefix`` is added alongside them."""
    mock_config = MagicMock()
    mock_config.CELERY_USE_SENTINEL = True
    mock_config.CELERY_SENTINEL_MASTER_NAME = "mymaster"
    mock_config.CELERY_SENTINEL_SOCKET_TIMEOUT = 0.1
    mock_config.CELERY_SENTINEL_PASSWORD = "secret"
    mock_config.REDIS_KEY_PREFIX = "enterprise-a"
    with patch("extensions.ext_celery.dify_config", mock_config):
        from extensions.ext_celery import get_celery_broker_transport_options

        result = get_celery_broker_transport_options()
        assert result["master_name"] == "mymaster"
        assert result["sentinel_kwargs"]["password"] == "secret"
        assert result["global_keyprefix"] == "enterprise-a:"
def test_get_celery_ssl_options_when_ssl_disabled(self):
"""Test SSL options when BROKER_USE_SSL is False."""
from configs import DifyConfig
@ -151,3 +192,49 @@ class TestCelerySSLConfiguration:
# Check that SSL is also applied to Redis backend
assert "redis_backend_use_ssl" in celery_app.conf
assert celery_app.conf["redis_backend_use_ssl"] is not None
def test_celery_init_applies_global_keyprefix_to_broker_and_backend_transport(self):
    """init_app propagates REDIS_KEY_PREFIX into both broker and result-backend
    Redis transport options as ``global_keyprefix``."""
    mock_config = MagicMock()
    mock_config.BROKER_USE_SSL = False
    mock_config.REDIS_KEY_PREFIX = "enterprise-a"
    mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
    mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
    mock_config.CELERY_BACKEND = "redis"
    mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
    mock_config.CELERY_USE_SENTINEL = False
    mock_config.LOG_FORMAT = "%(message)s"
    mock_config.LOG_TZ = "UTC"
    mock_config.LOG_FILE = None
    mock_config.CELERY_TASK_ANNOTATIONS = {}
    mock_config.CELERY_BEAT_SCHEDULER_TIME = 1
    # Disable every optional beat/background feature so init_app exercises
    # only the transport configuration under test.
    mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False
    mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False
    mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False
    mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False
    mock_config.ENABLE_CLEAN_MESSAGES = False
    mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
    mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
    mock_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK = False
    mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
    mock_config.MARKETPLACE_ENABLED = False
    mock_config.WORKFLOW_LOG_CLEANUP_ENABLED = False
    mock_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK = False
    mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False
    mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1
    mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
    mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
    mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
    mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
    mock_config.ENTERPRISE_ENABLED = False
    mock_config.ENTERPRISE_TELEMETRY_ENABLED = False
    with patch("extensions.ext_celery.dify_config", mock_config):
        from dify_app import DifyApp
        from extensions.ext_celery import init_app

        app = DifyApp(__name__)
        celery_app = init_app(app)
        assert celery_app.conf["broker_transport_options"]["global_keyprefix"] == "enterprise-a:"
        assert celery_app.conf["result_backend_transport_options"]["global_keyprefix"] == "enterprise-a:"

View File

@ -6,6 +6,7 @@ from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastCh
def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
channel = ext_redis.get_pubsub_broadcast_channel()
@ -14,6 +15,7 @@ def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
channel = ext_redis.get_pubsub_broadcast_channel()

View File

@ -1,12 +1,15 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from redis import RedisError
from redis.retry import Retry
from extensions.ext_redis import (
RedisClientWrapper,
_get_base_redis_params,
_get_cluster_connection_health_params,
_get_connection_health_params,
_normalize_redis_key_prefix,
_serialize_redis_name,
redis_fallback,
)
@ -123,3 +126,99 @@ class TestRedisFallback:
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"
class TestRedisKeyPrefixHelpers:
    """Unit tests for the key-prefix helpers re-exported by ext_redis."""

    def test_normalize_redis_key_prefix_trims_whitespace(self):
        """Surrounding whitespace is stripped from a configured prefix."""
        assert _normalize_redis_key_prefix(" enterprise-a ") == "enterprise-a"

    def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self):
        """A whitespace-only prefix collapses to the empty string."""
        assert _normalize_redis_key_prefix(" ") == ""

    def test_serialize_redis_name_returns_original_when_prefix_empty(self):
        """An empty prefix leaves the logical name unchanged."""
        assert _serialize_redis_name("model_lb_index:test", "") == "model_lb_index:test"

    def test_serialize_redis_name_adds_single_colon_separator(self):
        """Exactly one colon separates the prefix from the logical name."""
        assert _serialize_redis_name("model_lb_index:test", "enterprise-a") == "enterprise-a:model_lb_index:test"
class TestRedisClientWrapperKeyPrefix:
    """Tests that RedisClientWrapper applies REDIS_KEY_PREFIX to key arguments
    before delegating to the underlying client."""

    def test_wrapper_get_prefixes_string_keys(self):
        """String keys passed to get() are prefixed."""
        mock_client = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(mock_client)
        with patch("extensions.ext_redis.dify_config") as mock_config:
            mock_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.get("oauth_state:abc")
            mock_client.get.assert_called_once_with("enterprise-a:oauth_state:abc")

    def test_wrapper_delete_prefixes_multiple_keys(self):
        """Every key in a variadic delete() call is prefixed."""
        mock_client = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(mock_client)
        with patch("extensions.ext_redis.dify_config") as mock_config:
            mock_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.delete("key:a", "key:b")
            mock_client.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b")

    def test_wrapper_lock_prefixes_lock_name(self):
        """Lock names are prefixed while keyword options pass through unchanged."""
        mock_client = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(mock_client)
        with patch("extensions.ext_redis.dify_config") as mock_config:
            mock_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.lock("resource-lock", timeout=10)
            mock_client.lock.assert_called_once()
            args, kwargs = mock_client.lock.call_args
            assert args == ("enterprise-a:resource-lock",)
            assert kwargs["timeout"] == 10

    def test_wrapper_hash_operations_prefix_key_name(self):
        """hset/hgetall prefix the hash key; field names stay untouched."""
        mock_client = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(mock_client)
        with patch("extensions.ext_redis.dify_config") as mock_config:
            mock_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.hset("hash:key", "field", "value")
            wrapper.hgetall("hash:key")
            mock_client.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value")
            mock_client.hgetall.assert_called_once_with("enterprise-a:hash:key")

    def test_wrapper_zadd_prefixes_sorted_set_name(self):
        """zadd prefixes the sorted-set key; the member mapping is forwarded as-is."""
        mock_client = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(mock_client)
        with patch("extensions.ext_redis.dify_config") as mock_config:
            mock_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.zadd("zset:key", {"member": 1})
            mock_client.zadd.assert_called_once()
            args, kwargs = mock_client.zadd.call_args
            assert args == ("enterprise-a:zset:key", {"member": 1})
            assert kwargs["nx"] is False

    def test_wrapper_preserves_keys_when_prefix_is_empty(self):
        """A whitespace-only prefix leaves keys unprefixed (backward-compatible)."""
        mock_client = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(mock_client)
        with patch("extensions.ext_redis.dify_config") as mock_config:
            mock_config.REDIS_KEY_PREFIX = " "
            wrapper.get("plain:key")
            mock_client.get.assert_called_once_with("plain:key")

View File

@ -139,6 +139,28 @@ class TestTopic:
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
def test_publish_prefixes_regular_topic(self, mock_redis_client: MagicMock):
    """publish() targets the prefixed physical topic name."""
    with patch("extensions.redis_names.dify_config") as mock_config:
        mock_config.REDIS_KEY_PREFIX = "enterprise-a"
        # Topic must be constructed under the patch: the prefix is resolved at init.
        topic = Topic(mock_redis_client, "test-topic")
        topic.publish(b"test message")
        mock_redis_client.publish.assert_called_once_with("enterprise-a:test-topic", b"test message")
def test_subscribe_prefixes_regular_topic(self, mock_redis_client: MagicMock):
    """subscribe() registers the prefixed physical topic name with pubsub."""
    with patch("extensions.redis_names.dify_config") as mock_config:
        mock_config.REDIS_KEY_PREFIX = "enterprise-a"
        topic = Topic(mock_redis_client, "test-topic")
        subscription = topic.subscribe()
        try:
            subscription._start_if_needed()
        finally:
            # Always release the subscription, even if starting it raises.
            subscription.close()
        mock_redis_client.pubsub.return_value.subscribe.assert_called_once_with("enterprise-a:test-topic")
class TestShardedTopic:
"""Test cases for the ShardedTopic class."""
@ -176,6 +198,15 @@ class TestShardedTopic:
mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload)
def test_publish_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
    """Sharded publish() targets the prefixed physical topic name via spublish."""
    with patch("extensions.redis_names.dify_config") as mock_config:
        mock_config.REDIS_KEY_PREFIX = "enterprise-a"
        sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
        sharded_topic.publish(b"test sharded message")
        mock_redis_client.spublish.assert_called_once_with("enterprise-a:test-sharded-topic", b"test sharded message")
def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock):
"""Test that subscribe() returns a _RedisShardedSubscription instance."""
subscription = sharded_topic.subscribe()
@ -185,6 +216,19 @@ class TestShardedTopic:
assert subscription._pubsub is mock_redis_client.pubsub.return_value
assert subscription._topic == "test-sharded-topic"
def test_subscribe_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
    """Sharded subscribe() registers the prefixed topic name via ssubscribe."""
    with patch("extensions.redis_names.dify_config") as mock_config:
        mock_config.REDIS_KEY_PREFIX = "enterprise-a"
        sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
        subscription = sharded_topic.subscribe()
        try:
            subscription._start_if_needed()
        finally:
            # Always release the subscription, even if starting it raises.
            subscription.close()
        mock_redis_client.pubsub.return_value.ssubscribe.assert_called_once_with("enterprise-a:test-sharded-topic")
@dataclasses.dataclass(frozen=True)
class SubscriptionTestCase:

View File

@ -2,6 +2,7 @@ import threading
import time
from dataclasses import dataclass
from typing import cast
from unittest.mock import patch
import pytest
@ -150,6 +151,25 @@ class TestStreamsBroadcastChannel:
# Expire called after publish
assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
def test_topic_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
    """The logical topic name is kept while the stream key gets the prefix."""
    with patch("extensions.redis_names.dify_config") as mock_config:
        mock_config.REDIS_KEY_PREFIX = "enterprise-a"
        topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("alpha")
        assert topic._topic == "alpha"
        assert topic._key == "enterprise-a:stream:alpha"
def test_publish_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
    """publish() writes to, and sets expiry on, the prefixed stream key."""
    with patch("extensions.redis_names.dify_config") as mock_config:
        mock_config.REDIS_KEY_PREFIX = "enterprise-a"
        topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("beta")
        topic.publish(b"hello")
        assert fake_redis._store["enterprise-a:stream:beta"][0][1] == {b"data": b"hello"}
        assert fake_redis._expire_calls.get("enterprise-a:stream:beta", 0) >= 1
def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("producer-subscriber")

View File

@ -351,6 +351,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# Optional: limit total Redis connections used by API/Worker (unset for default)
# Align with API's REDIS_MAX_CONNECTIONS in configs
REDIS_MAX_CONNECTIONS=

View File

@ -88,6 +88,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
1. **Redis Configuration**:
- `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings.
- `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
1. **Celery Configuration**:

View File

@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env
REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
REDIS_DB: ${REDIS_DB:-0}
REDIS_KEY_PREFIX: ${REDIS_KEY_PREFIX:-}
REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-}
REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
REDIS_SENTINELS: ${REDIS_SENTINELS:-}