From 736880e046581ca5b0b65441cee1c1e96cc5caac Mon Sep 17 00:00:00 2001 From: Blackoutta <37723456+Blackoutta@users.noreply.github.com> Date: Tue, 14 Apr 2026 17:31:41 +0800 Subject: [PATCH] feat: support configurable redis key prefix (#35139) --- api/.env.example | 3 + api/configs/middleware/cache/redis_config.py | 5 + api/extensions/ext_celery.py | 24 +- api/extensions/ext_redis.py | 217 ++++++++++++------ api/extensions/redis_names.py | 32 +++ api/libs/broadcast_channel/redis/channel.py | 6 +- .../redis/sharded_channel.py | 6 +- .../redis/streams_channel.py | 3 +- api/libs/db_migration_lock.py | 5 +- api/tests/integration_tests/.env.example | 1 + .../unit_tests/configs/test_dify_config.py | 35 +++ .../unit_tests/extensions/test_celery_ssl.py | 87 +++++++ .../extensions/test_pubsub_channel.py | 2 + api/tests/unit_tests/extensions/test_redis.py | 101 +++++++- .../redis/test_channel_unit_tests.py | 44 ++++ .../redis/test_streams_channel_unit_tests.py | 20 ++ docker/.env.example | 3 + docker/README.md | 1 + docker/docker-compose.yaml | 1 + 19 files changed, 522 insertions(+), 74 deletions(-) create mode 100644 api/extensions/redis_names.py diff --git a/api/.env.example b/api/.env.example index a04a18944a..beb820e797 100644 --- a/api/.env.example +++ b/api/.env.example @@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts. +# Leave empty to preserve current unprefixed behavior. +REDIS_KEY_PREFIX= # redis Sentinel configuration. REDIS_USE_SENTINEL=false diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index b49275758a..2def0a0d4e 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -32,6 +32,11 @@ class RedisConfig(BaseSettings): default=0, ) + REDIS_KEY_PREFIX: str = Field( + description="Optional global prefix for Redis keys, topics, and transport artifacts", + default="", + ) + REDIS_USE_SSL: bool = Field( description="Enable SSL/TLS for the Redis connection", default=False, diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 86b0550187..340f514fcc 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import normalize_redis_key_prefix class _CelerySentinelKwargsDict(TypedDict): @@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict): password: str | None -class CelerySentinelTransportDict(TypedDict): +class CelerySentinelTransportDict(TypedDict, total=False): master_name: str | None sentinel_kwargs: _CelerySentinelKwargsDict + global_keyprefix: str class CelerySSLOptionsDict(TypedDict): @@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None: def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. Redis Sentinel) for Celery connections.""" + transport_options: CelerySentinelTransportDict | dict[str, Any] if dify_config.CELERY_USE_SENTINEL: - return CelerySentinelTransportDict( + transport_options = CelerySentinelTransportDict( master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, sentinel_kwargs=_CelerySentinelKwargsDict( socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, password=dify_config.CELERY_SENTINEL_PASSWORD, ), ) - return {} + else: + transport_options = {} + + global_keyprefix = get_celery_redis_global_keyprefix() + if global_keyprefix: + transport_options["global_keyprefix"] = global_keyprefix + + return transport_options + + +def get_celery_redis_global_keyprefix() -> str | None: + """Return the Redis transport prefix for Celery when namespace isolation is enabled.""" + normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + if not normalized_prefix: + return None + return f"{normalized_prefix}:" def init_app(app: DifyApp) -> Celery: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 20f05b8b9e..9f7f73765e 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import Any, Union, cast import redis from redis import RedisError @@ -18,17 +18,26 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import ( + normalize_redis_key_prefix, + serialize_redis_name, + serialize_redis_name_arg, + serialize_redis_name_args, +) from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel -if TYPE_CHECKING: - from redis.lock import Lock - logger = logging.getLogger(__name__) +_normalize_redis_key_prefix = normalize_redis_key_prefix +_serialize_redis_name = serialize_redis_name +_serialize_redis_name_arg = serialize_redis_name_arg +_serialize_redis_name_args = serialize_redis_name_args + + class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global @@ -59,68 +68,148 @@ class RedisClientWrapper: if self._client is None: self._client = client - if TYPE_CHECKING: - # Type hints for IDE support and static analysis - # These are not executed at runtime but provide type information - def get(self, name: str | bytes) -> Any: ... - - def set( - self, - name: str | bytes, - value: Any, - ex: int | None = None, - px: int | None = None, - nx: bool = False, - xx: bool = False, - keepttl: bool = False, - get: bool = False, - exat: int | None = None, - pxat: int | None = None, - ) -> Any: ... - - def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... - def setnx(self, name: str | bytes, value: Any) -> Any: ... - def delete(self, *names: str | bytes) -> Any: ... - def incr(self, name: str | bytes, amount: int = 1) -> Any: ... - def expire( - self, - name: str | bytes, - time: int | timedelta, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def lock( - self, - name: str, - timeout: float | None = None, - sleep: float = 0.1, - blocking: bool = True, - blocking_timeout: float | None = None, - thread_local: bool = True, - ) -> Lock: ... - def zadd( - self, - name: str | bytes, - mapping: dict[str | bytes | int | float, float | int | str | bytes], - nx: bool = False, - xx: bool = False, - ch: bool = False, - incr: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... - def zcard(self, name: str | bytes) -> Any: ... - def getdel(self, name: str | bytes) -> Any: ... - def pubsub(self) -> PubSub: ... - def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... - - def __getattr__(self, item: str) -> Any: + def _require_client(self) -> redis.Redis | RedisCluster: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") - return getattr(self._client, item) + return self._client + + def _get_prefix(self) -> str: + return dify_config.REDIS_KEY_PREFIX + + def get(self, name: str | bytes) -> Any: + return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix())) + + def set( + self, + name: str | bytes, + value: Any, + ex: int | None = None, + px: int | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: int | None = None, + pxat: int | None = None, + ) -> Any: + return self._require_client().set( + _serialize_redis_name_arg(name, self._get_prefix()), + value, + ex=ex, + px=px, + nx=nx, + xx=xx, + keepttl=keepttl, + get=get, + exat=exat, + pxat=pxat, + ) + + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: + return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value) + + def setnx(self, name: str | bytes, value: Any) -> Any: + return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value) + + def delete(self, *names: str | bytes) -> Any: + return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix())) + + def incr(self, name: str | bytes, amount: int = 1) -> Any: + return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount) + + def expire( + self, + name: str | bytes, + time: int | timedelta, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().expire( + _serialize_redis_name_arg(name, self._get_prefix()), + time, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) + + def exists(self, *names: str | bytes) -> Any: + return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix())) + + def ttl(self, name: str | bytes) -> Any: + return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix())) + + def getdel(self, name: str | bytes) -> Any: + return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix())) + + def lock( + self, + name: str, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + thread_local: bool = True, + ) -> Any: + return self._require_client().lock( + _serialize_redis_name(name, self._get_prefix()), + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any: + return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs) + + def hgetall(self, name: str | bytes) -> Any: + return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix())) + + def hdel(self, name: str | bytes, *keys: str | bytes) -> Any: + return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys) + + def hlen(self, name: str | bytes) -> Any: + return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix())) + + def zadd( + self, + name: str | bytes, + mapping: dict[str | bytes | int | float, float | int | str | bytes], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().zadd( + _serialize_redis_name_arg(name, self._get_prefix()), + cast(Any, mapping), + nx=nx, + xx=xx, + ch=ch, + incr=incr, + gt=gt, + lt=lt, + ) + + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: + return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max) + + def zcard(self, name: str | bytes) -> Any: + return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix())) + + def pubsub(self) -> PubSub: + return self._require_client().pubsub() + + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: + return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint) + + def __getattr__(self, item: str) -> Any: + return getattr(self._require_client(), item) redis_client: RedisClientWrapper = RedisClientWrapper() diff --git a/api/extensions/redis_names.py b/api/extensions/redis_names.py new file mode 100644 index 0000000000..9e63416daf --- /dev/null +++ b/api/extensions/redis_names.py @@ -0,0 +1,32 @@ +from configs import dify_config + + +def normalize_redis_key_prefix(prefix: str | None) -> str: + """Normalize the configured Redis key prefix for consistent runtime use.""" + if prefix is None: + return "" + return prefix.strip() + + +def get_redis_key_prefix() -> str: + """Read and normalize the current Redis key prefix from config.""" + return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + + +def serialize_redis_name(name: str, prefix: str | None = None) -> str: + """Convert a logical Redis name into the physical name used in Redis.""" + normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix) + if not normalized_prefix: + return name + return f"{normalized_prefix}:{name}" + + +def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes: + """Prefix string Redis names while preserving bytes inputs unchanged.""" + if isinstance(name, bytes): + return name + return serialize_redis_name(name, prefix) + + +def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]: + return tuple(serialize_redis_name_arg(name, prefix) for name in names) diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 36aa1cd3e8..b76a23eb3c 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -32,12 +33,13 @@ class Topic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.publish(self._topic, payload) + self._client.publish(self._redis_topic, payload) def as_subscriber(self) -> Subscriber: return self @@ -46,7 +48,7 @@ class Topic: return _RedisSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index dddc92d099..919d8d622e 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -30,12 +31,13 @@ class ShardedTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr] + self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr] def as_subscriber(self) -> Subscriber: return self @@ -44,7 +46,7 @@ class ShardedTopic: return _RedisShardedSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index 983f785027..55ff6cd4f9 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -6,6 +6,7 @@ import threading from collections.abc import Iterator from typing import Self +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from libs.broadcast_channel.exc import SubscriptionClosedError from redis import Redis, RedisCluster @@ -35,7 +36,7 @@ class StreamsTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): self._client = redis_client self._topic = topic - self._key = f"stream:{topic}" + self._key = serialize_redis_name(f"stream:{topic}") self._retention_seconds = retention_seconds self.max_length = 5000 diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py index ca8956e397..b5fe38342a 100644 --- a/api/libs/db_migration_lock.py +++ b/api/libs/db_migration_lock.py @@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock: timeout=self._ttl_seconds, thread_local=False, ) - acquired = bool(self._lock.acquire(*args, **kwargs)) + lock = self._lock + if lock is None: + raise RuntimeError("Redis lock initialization failed.") + acquired = bool(lock.acquire(*args, **kwargs)) self._acquired = acquired if acquired: self._start_heartbeat() diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index f84d39aeb5..c07ab6d6bf 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -33,6 +33,7 @@ REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false REDIS_DB=0 +REDIS_KEY_PREFIX= # PostgreSQL database configuration DB_USERNAME=postgres diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 3089750c3e..bad246a4bb 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -236,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest. _ = DifyConfig().normalized_pubsub_redis_url +def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "" + + +def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "enterprise-a" + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index 81687ce5f8..366e45d86d 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -7,6 +7,47 @@ from unittest.mock import MagicMock, patch class TestCelerySSLConfiguration: """Test suite for Celery SSL configuration.""" + def test_get_celery_broker_transport_options_includes_global_keyprefix_for_redis(self): + mock_config = MagicMock() + mock_config.CELERY_USE_SENTINEL = False + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import get_celery_broker_transport_options + + result = get_celery_broker_transport_options() + + assert result["global_keyprefix"] == "enterprise-a:" + + def test_get_celery_broker_transport_options_omits_global_keyprefix_when_prefix_empty(self): + mock_config = MagicMock() + mock_config.CELERY_USE_SENTINEL = False + mock_config.REDIS_KEY_PREFIX = " " + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import get_celery_broker_transport_options + + result = get_celery_broker_transport_options() + + assert "global_keyprefix" not in result + + def test_get_celery_broker_transport_options_keeps_sentinel_and_adds_global_keyprefix(self): + mock_config = MagicMock() + mock_config.CELERY_USE_SENTINEL = True + mock_config.CELERY_SENTINEL_MASTER_NAME = "mymaster" + mock_config.CELERY_SENTINEL_SOCKET_TIMEOUT = 0.1 + mock_config.CELERY_SENTINEL_PASSWORD = "secret" + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import get_celery_broker_transport_options + + result = get_celery_broker_transport_options() + + assert result["master_name"] == "mymaster" + assert result["sentinel_kwargs"]["password"] == "secret" + assert result["global_keyprefix"] == "enterprise-a:" + def test_get_celery_ssl_options_when_ssl_disabled(self): """Test SSL options when BROKER_USE_SSL is False.""" from configs import DifyConfig @@ -151,3 +192,49 @@ class TestCelerySSLConfiguration: # Check that SSL is also applied to Redis backend assert "redis_backend_use_ssl" in celery_app.conf assert celery_app.conf["redis_backend_use_ssl"] is not None + + def test_celery_init_applies_global_keyprefix_to_broker_and_backend_transport(self): + mock_config = MagicMock() + mock_config.BROKER_USE_SSL = False + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1 + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.CELERY_BACKEND = "redis" + mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" + mock_config.CELERY_USE_SENTINEL = False + mock_config.LOG_FORMAT = "%(message)s" + mock_config.LOG_TZ = "UTC" + mock_config.LOG_FILE = None + mock_config.CELERY_TASK_ANNOTATIONS = {} + + mock_config.CELERY_BEAT_SCHEDULER_TIME = 1 + mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False + mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False + mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False + mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False + mock_config.ENABLE_CLEAN_MESSAGES = False + mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False + mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False + mock_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK = False + mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False + mock_config.MARKETPLACE_ENABLED = False + mock_config.WORKFLOW_LOG_CLEANUP_ENABLED = False + mock_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK = False + mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False + mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1 + mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False + mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15 + mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False + mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30 + mock_config.ENTERPRISE_ENABLED = False + mock_config.ENTERPRISE_TELEMETRY_ENABLED = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from dify_app import DifyApp + from extensions.ext_celery import init_app + + app = DifyApp(__name__) + celery_app = init_app(app) + + assert celery_app.conf["broker_transport_options"]["global_keyprefix"] == "enterprise-a:" + assert celery_app.conf["result_backend_transport_options"]["global_keyprefix"] == "enterprise-a:" diff --git a/api/tests/unit_tests/extensions/test_pubsub_channel.py b/api/tests/unit_tests/extensions/test_pubsub_channel.py index a5b41a7266..926c406ad4 100644 --- a/api/tests/unit_tests/extensions/test_pubsub_channel.py +++ b/api/tests/unit_tests/extensions/test_pubsub_channel.py @@ -6,6 +6,7 @@ from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastCh def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object()) channel = ext_redis.get_pubsub_broadcast_channel() @@ -14,6 +15,7 @@ def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): def test_get_pubsub_broadcast_channel_sharded(monkeypatch): monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object()) channel = ext_redis.get_pubsub_broadcast_channel() diff --git a/api/tests/unit_tests/extensions/test_redis.py b/api/tests/unit_tests/extensions/test_redis.py index 5e9be4ab9b..21248439bf 100644 --- a/api/tests/unit_tests/extensions/test_redis.py +++ b/api/tests/unit_tests/extensions/test_redis.py @@ -1,12 +1,15 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from redis import RedisError from redis.retry import Retry from extensions.ext_redis import ( + RedisClientWrapper, _get_base_redis_params, _get_cluster_connection_health_params, _get_connection_health_params, + _normalize_redis_key_prefix, + _serialize_redis_name, redis_fallback, ) @@ -123,3 +126,99 @@ class TestRedisFallback: assert test_func.__name__ == "test_func" assert test_func.__doc__ == "Test function docstring" + + +class TestRedisKeyPrefixHelpers: + def test_normalize_redis_key_prefix_trims_whitespace(self): + assert _normalize_redis_key_prefix(" enterprise-a ") == "enterprise-a" + + def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self): + assert _normalize_redis_key_prefix(" ") == "" + + def test_serialize_redis_name_returns_original_when_prefix_empty(self): + assert _serialize_redis_name("model_lb_index:test", "") == "model_lb_index:test" + + def test_serialize_redis_name_adds_single_colon_separator(self): + assert _serialize_redis_name("model_lb_index:test", "enterprise-a") == "enterprise-a:model_lb_index:test" + + +class TestRedisClientWrapperKeyPrefix: + def test_wrapper_get_prefixes_string_keys(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.get("oauth_state:abc") + + mock_client.get.assert_called_once_with("enterprise-a:oauth_state:abc") + + def test_wrapper_delete_prefixes_multiple_keys(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.delete("key:a", "key:b") + + mock_client.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b") + + def test_wrapper_lock_prefixes_lock_name(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.lock("resource-lock", timeout=10) + + mock_client.lock.assert_called_once() + args, kwargs = mock_client.lock.call_args + assert args == ("enterprise-a:resource-lock",) + assert kwargs["timeout"] == 10 + + def test_wrapper_hash_operations_prefix_key_name(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.hset("hash:key", "field", "value") + wrapper.hgetall("hash:key") + + mock_client.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value") + mock_client.hgetall.assert_called_once_with("enterprise-a:hash:key") + + def test_wrapper_zadd_prefixes_sorted_set_name(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.zadd("zset:key", {"member": 1}) + + mock_client.zadd.assert_called_once() + args, kwargs = mock_client.zadd.call_args + assert args == ("enterprise-a:zset:key", {"member": 1}) + assert kwargs["nx"] is False + + def test_wrapper_preserves_keys_when_prefix_is_empty(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = " " + + wrapper.get("plain:key") + + mock_client.get.assert_called_once_with("plain:key") diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index 460374b6f6..8bef01c1ed 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -139,6 +139,28 @@ class TestTopic: mock_redis_client.publish.assert_called_once_with("test-topic", payload) + def test_publish_prefixes_regular_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + topic = Topic(mock_redis_client, "test-topic") + + topic.publish(b"test message") + + mock_redis_client.publish.assert_called_once_with("enterprise-a:test-topic", b"test message") + + def test_subscribe_prefixes_regular_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + topic = Topic(mock_redis_client, "test-topic") + + subscription = topic.subscribe() + try: + subscription._start_if_needed() + finally: + subscription.close() + + mock_redis_client.pubsub.return_value.subscribe.assert_called_once_with("enterprise-a:test-topic") + class TestShardedTopic: """Test cases for the ShardedTopic class.""" @@ -176,6 +198,15 @@ class TestShardedTopic: mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload) + def test_publish_prefixes_sharded_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic") + + sharded_topic.publish(b"test sharded message") + + mock_redis_client.spublish.assert_called_once_with("enterprise-a:test-sharded-topic", b"test sharded message") + def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock): """Test that subscribe() returns a _RedisShardedSubscription instance.""" subscription = sharded_topic.subscribe() @@ -185,6 +216,19 @@ class TestShardedTopic: assert subscription._pubsub is mock_redis_client.pubsub.return_value assert subscription._topic == "test-sharded-topic" + def test_subscribe_prefixes_sharded_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic") + + subscription = sharded_topic.subscribe() + try: + subscription._start_if_needed() + finally: + subscription.close() + + mock_redis_client.pubsub.return_value.ssubscribe.assert_called_once_with("enterprise-a:test-sharded-topic") + @dataclasses.dataclass(frozen=True) class SubscriptionTestCase: diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index 0886b70ee5..fd9e5ca5b3 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -2,6 +2,7 @@ import threading import time from dataclasses import dataclass from typing import cast +from unittest.mock import patch import pytest @@ -150,6 +151,25 @@ class TestStreamsBroadcastChannel: # Expire called after publish assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + def test_topic_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("alpha") + + assert topic._topic == "alpha" + assert topic._key == "enterprise-a:stream:alpha" + + def test_publish_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("beta") + + topic.publish(b"hello") + + assert fake_redis._store["enterprise-a:stream:beta"][0][1] == {b"data": b"hello"} + assert fake_redis._expire_calls.get("enterprise-a:stream:beta", 0) >= 1 + def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("producer-subscriber") diff --git a/docker/.env.example b/docker/.env.example index 4426a882f1..856b04a3df 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -351,6 +351,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts. +# Leave empty to preserve current unprefixed behavior. +REDIS_KEY_PREFIX= # Optional: limit total Redis connections used by API/Worker (unset for default) # Align with API's REDIS_MAX_CONNECTIONS in configs REDIS_MAX_CONNECTIONS= diff --git a/docker/README.md b/docker/README.md index 4c40317f37..3130fa9886 100644 --- a/docker/README.md +++ b/docker/README.md @@ -88,6 +88,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w 1. **Redis Configuration**: - `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings. + - `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts. 1. **Celery Configuration**: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1fc1cfdf9e..c1ddba4f80 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} + REDIS_KEY_PREFIX: ${REDIS_KEY_PREFIX:-} REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-}