mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
feat: support configurable redis key prefix (#35139)
This commit is contained in:
parent
bd7a9b5fcf
commit
736880e046
@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
|
||||
REDIS_SSL_KEYFILE=
|
||||
# Path to client private key file for SSL authentication
|
||||
REDIS_DB=0
|
||||
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
# Leave empty to preserve current unprefixed behavior.
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# redis Sentinel configuration.
|
||||
REDIS_USE_SENTINEL=false
|
||||
|
||||
5
api/configs/middleware/cache/redis_config.py
vendored
5
api/configs/middleware/cache/redis_config.py
vendored
@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_KEY_PREFIX: str = Field(
|
||||
description="Optional global prefix for Redis keys, topics, and transport artifacts",
|
||||
default="",
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="Enable SSL/TLS for the Redis connection",
|
||||
default=False,
|
||||
|
||||
@ -9,6 +9,7 @@ from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.redis_names import normalize_redis_key_prefix
|
||||
|
||||
|
||||
class _CelerySentinelKwargsDict(TypedDict):
|
||||
@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
|
||||
password: str | None
|
||||
|
||||
|
||||
class CelerySentinelTransportDict(TypedDict):
|
||||
class CelerySentinelTransportDict(TypedDict, total=False):
|
||||
master_name: str | None
|
||||
sentinel_kwargs: _CelerySentinelKwargsDict
|
||||
global_keyprefix: str
|
||||
|
||||
|
||||
class CelerySSLOptionsDict(TypedDict):
|
||||
@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
|
||||
|
||||
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
|
||||
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
|
||||
transport_options: CelerySentinelTransportDict | dict[str, Any]
|
||||
if dify_config.CELERY_USE_SENTINEL:
|
||||
return CelerySentinelTransportDict(
|
||||
transport_options = CelerySentinelTransportDict(
|
||||
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
sentinel_kwargs=_CelerySentinelKwargsDict(
|
||||
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
password=dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
),
|
||||
)
|
||||
return {}
|
||||
else:
|
||||
transport_options = {}
|
||||
|
||||
global_keyprefix = get_celery_redis_global_keyprefix()
|
||||
if global_keyprefix:
|
||||
transport_options["global_keyprefix"] = global_keyprefix
|
||||
|
||||
return transport_options
|
||||
|
||||
|
||||
def get_celery_redis_global_keyprefix() -> str | None:
|
||||
"""Return the Redis transport prefix for Celery when namespace isolation is enabled."""
|
||||
normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
|
||||
if not normalized_prefix:
|
||||
return None
|
||||
return f"{normalized_prefix}:"
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> Celery:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import ssl
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
@ -18,17 +18,26 @@ from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.redis_names import (
|
||||
normalize_redis_key_prefix,
|
||||
serialize_redis_name,
|
||||
serialize_redis_name_arg,
|
||||
serialize_redis_name_args,
|
||||
)
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_normalize_redis_key_prefix = normalize_redis_key_prefix
|
||||
_serialize_redis_name = serialize_redis_name
|
||||
_serialize_redis_name_arg = serialize_redis_name_arg
|
||||
_serialize_redis_name_args = serialize_redis_name_args
|
||||
|
||||
|
||||
class RedisClientWrapper:
|
||||
"""
|
||||
A wrapper class for the Redis client that addresses the issue where the global
|
||||
@ -59,68 +68,148 @@ class RedisClientWrapper:
|
||||
if self._client is None:
|
||||
self._client = client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Type hints for IDE support and static analysis
|
||||
# These are not executed at runtime but provide type information
|
||||
def get(self, name: str | bytes) -> Any: ...
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any: ...
|
||||
def delete(self, *names: str | bytes) -> Any: ...
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Lock: ...
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
def pubsub(self) -> PubSub: ...
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
def _require_client(self) -> redis.Redis | RedisCluster:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Redis client is not initialized. Call init_app first.")
|
||||
return getattr(self._client, item)
|
||||
return self._client
|
||||
|
||||
def _get_prefix(self) -> str:
|
||||
return dify_config.REDIS_KEY_PREFIX
|
||||
|
||||
def get(self, name: str | bytes) -> Any:
|
||||
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any:
|
||||
return self._require_client().set(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
value,
|
||||
ex=ex,
|
||||
px=px,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
keepttl=keepttl,
|
||||
get=get,
|
||||
exat=exat,
|
||||
pxat=pxat,
|
||||
)
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
|
||||
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
|
||||
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any:
|
||||
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
|
||||
|
||||
def delete(self, *names: str | bytes) -> Any:
|
||||
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
|
||||
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any:
|
||||
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
|
||||
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any:
|
||||
return self._require_client().expire(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
time,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
gt=gt,
|
||||
lt=lt,
|
||||
)
|
||||
|
||||
def exists(self, *names: str | bytes) -> Any:
|
||||
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
|
||||
|
||||
def ttl(self, name: str | bytes) -> Any:
|
||||
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def getdel(self, name: str | bytes) -> Any:
|
||||
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Any:
|
||||
return self._require_client().lock(
|
||||
_serialize_redis_name(name, self._get_prefix()),
|
||||
timeout=timeout,
|
||||
sleep=sleep,
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
thread_local=thread_local,
|
||||
)
|
||||
|
||||
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
|
||||
|
||||
def hgetall(self, name: str | bytes) -> Any:
|
||||
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
|
||||
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
|
||||
|
||||
def hlen(self, name: str | bytes) -> Any:
|
||||
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any:
|
||||
return self._require_client().zadd(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
cast(Any, mapping),
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
ch=ch,
|
||||
incr=incr,
|
||||
gt=gt,
|
||||
lt=lt,
|
||||
)
|
||||
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
|
||||
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
|
||||
|
||||
def zcard(self, name: str | bytes) -> Any:
|
||||
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def pubsub(self) -> PubSub:
|
||||
return self._require_client().pubsub()
|
||||
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
|
||||
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
return getattr(self._require_client(), item)
|
||||
|
||||
|
||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
|
||||
32
api/extensions/redis_names.py
Normal file
32
api/extensions/redis_names.py
Normal file
@ -0,0 +1,32 @@
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def normalize_redis_key_prefix(prefix: str | None) -> str:
|
||||
"""Normalize the configured Redis key prefix for consistent runtime use."""
|
||||
if prefix is None:
|
||||
return ""
|
||||
return prefix.strip()
|
||||
|
||||
|
||||
def get_redis_key_prefix() -> str:
|
||||
"""Read and normalize the current Redis key prefix from config."""
|
||||
return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
|
||||
|
||||
|
||||
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
|
||||
"""Convert a logical Redis name into the physical name used in Redis."""
|
||||
normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix)
|
||||
if not normalized_prefix:
|
||||
return name
|
||||
return f"{normalized_prefix}:{name}"
|
||||
|
||||
|
||||
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
|
||||
"""Prefix string Redis names while preserving bytes inputs unchanged."""
|
||||
if isinstance(name, bytes):
|
||||
return name
|
||||
return serialize_redis_name(name, prefix)
|
||||
|
||||
|
||||
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
|
||||
return tuple(serialize_redis_name_arg(name, prefix) for name in names)
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
@ -32,12 +33,13 @@ class Topic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.publish(self._topic, payload)
|
||||
self._client.publish(self._redis_topic, payload)
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
@ -46,7 +48,7 @@ class Topic:
|
||||
return _RedisSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
topic=self._redis_topic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
@ -30,12 +31,13 @@ class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
|
||||
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
@ -44,7 +46,7 @@ class ShardedTopic:
|
||||
return _RedisShardedSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
topic=self._redis_topic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Self
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
@ -35,7 +36,7 @@ class StreamsTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = f"stream:{topic}"
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self.max_length = 5000
|
||||
|
||||
|
||||
@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
|
||||
timeout=self._ttl_seconds,
|
||||
thread_local=False,
|
||||
)
|
||||
acquired = bool(self._lock.acquire(*args, **kwargs))
|
||||
lock = self._lock
|
||||
if lock is None:
|
||||
raise RuntimeError("Redis lock initialization failed.")
|
||||
acquired = bool(lock.acquire(*args, **kwargs))
|
||||
self._acquired = acquired
|
||||
if acquired:
|
||||
self._start_heartbeat()
|
||||
|
||||
@ -33,6 +33,7 @@ REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
|
||||
@ -236,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
|
||||
_ = DifyConfig().normalized_pubsub_redis_url
|
||||
|
||||
|
||||
def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
assert config.REDIS_KEY_PREFIX == ""
|
||||
|
||||
|
||||
def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
|
||||
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
assert config.REDIS_KEY_PREFIX == "enterprise-a"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
|
||||
[
|
||||
|
||||
@ -7,6 +7,47 @@ from unittest.mock import MagicMock, patch
|
||||
class TestCelerySSLConfiguration:
|
||||
"""Test suite for Celery SSL configuration."""
|
||||
|
||||
def test_get_celery_broker_transport_options_includes_global_keyprefix_for_redis(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import get_celery_broker_transport_options
|
||||
|
||||
result = get_celery_broker_transport_options()
|
||||
|
||||
assert result["global_keyprefix"] == "enterprise-a:"
|
||||
|
||||
def test_get_celery_broker_transport_options_omits_global_keyprefix_when_prefix_empty(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.REDIS_KEY_PREFIX = " "
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import get_celery_broker_transport_options
|
||||
|
||||
result = get_celery_broker_transport_options()
|
||||
|
||||
assert "global_keyprefix" not in result
|
||||
|
||||
def test_get_celery_broker_transport_options_keeps_sentinel_and_adds_global_keyprefix(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.CELERY_USE_SENTINEL = True
|
||||
mock_config.CELERY_SENTINEL_MASTER_NAME = "mymaster"
|
||||
mock_config.CELERY_SENTINEL_SOCKET_TIMEOUT = 0.1
|
||||
mock_config.CELERY_SENTINEL_PASSWORD = "secret"
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import get_celery_broker_transport_options
|
||||
|
||||
result = get_celery_broker_transport_options()
|
||||
|
||||
assert result["master_name"] == "mymaster"
|
||||
assert result["sentinel_kwargs"]["password"] == "secret"
|
||||
assert result["global_keyprefix"] == "enterprise-a:"
|
||||
|
||||
def test_get_celery_ssl_options_when_ssl_disabled(self):
|
||||
"""Test SSL options when BROKER_USE_SSL is False."""
|
||||
from configs import DifyConfig
|
||||
@ -151,3 +192,49 @@ class TestCelerySSLConfiguration:
|
||||
# Check that SSL is also applied to Redis backend
|
||||
assert "redis_backend_use_ssl" in celery_app.conf
|
||||
assert celery_app.conf["redis_backend_use_ssl"] is not None
|
||||
|
||||
def test_celery_init_applies_global_keyprefix_to_broker_and_backend_transport(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.BROKER_USE_SSL = False
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.CELERY_BACKEND = "redis"
|
||||
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.LOG_FORMAT = "%(message)s"
|
||||
mock_config.LOG_TZ = "UTC"
|
||||
mock_config.LOG_FILE = None
|
||||
mock_config.CELERY_TASK_ANNOTATIONS = {}
|
||||
|
||||
mock_config.CELERY_BEAT_SCHEDULER_TIME = 1
|
||||
mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False
|
||||
mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False
|
||||
mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False
|
||||
mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False
|
||||
mock_config.ENABLE_CLEAN_MESSAGES = False
|
||||
mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
|
||||
mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
|
||||
mock_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK = False
|
||||
mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
|
||||
mock_config.MARKETPLACE_ENABLED = False
|
||||
mock_config.WORKFLOW_LOG_CLEANUP_ENABLED = False
|
||||
mock_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK = False
|
||||
mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False
|
||||
mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1
|
||||
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
|
||||
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
|
||||
mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
|
||||
mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
mock_config.ENTERPRISE_TELEMETRY_ENABLED = False
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_celery import init_app
|
||||
|
||||
app = DifyApp(__name__)
|
||||
celery_app = init_app(app)
|
||||
|
||||
assert celery_app.conf["broker_transport_options"]["global_keyprefix"] == "enterprise-a:"
|
||||
assert celery_app.conf["result_backend_transport_options"]["global_keyprefix"] == "enterprise-a:"
|
||||
|
||||
@ -6,6 +6,7 @@ from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastCh
|
||||
|
||||
def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
|
||||
|
||||
channel = ext_redis.get_pubsub_broadcast_channel()
|
||||
|
||||
@ -14,6 +15,7 @@ def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
|
||||
|
||||
def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
|
||||
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
|
||||
|
||||
channel = ext_redis.get_pubsub_broadcast_channel()
|
||||
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from redis import RedisError
|
||||
from redis.retry import Retry
|
||||
|
||||
from extensions.ext_redis import (
|
||||
RedisClientWrapper,
|
||||
_get_base_redis_params,
|
||||
_get_cluster_connection_health_params,
|
||||
_get_connection_health_params,
|
||||
_normalize_redis_key_prefix,
|
||||
_serialize_redis_name,
|
||||
redis_fallback,
|
||||
)
|
||||
|
||||
@ -123,3 +126,99 @@ class TestRedisFallback:
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
|
||||
|
||||
class TestRedisKeyPrefixHelpers:
|
||||
def test_normalize_redis_key_prefix_trims_whitespace(self):
|
||||
assert _normalize_redis_key_prefix(" enterprise-a ") == "enterprise-a"
|
||||
|
||||
def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self):
|
||||
assert _normalize_redis_key_prefix(" ") == ""
|
||||
|
||||
def test_serialize_redis_name_returns_original_when_prefix_empty(self):
|
||||
assert _serialize_redis_name("model_lb_index:test", "") == "model_lb_index:test"
|
||||
|
||||
def test_serialize_redis_name_adds_single_colon_separator(self):
|
||||
assert _serialize_redis_name("model_lb_index:test", "enterprise-a") == "enterprise-a:model_lb_index:test"
|
||||
|
||||
|
||||
class TestRedisClientWrapperKeyPrefix:
|
||||
def test_wrapper_get_prefixes_string_keys(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.get("oauth_state:abc")
|
||||
|
||||
mock_client.get.assert_called_once_with("enterprise-a:oauth_state:abc")
|
||||
|
||||
def test_wrapper_delete_prefixes_multiple_keys(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.delete("key:a", "key:b")
|
||||
|
||||
mock_client.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b")
|
||||
|
||||
def test_wrapper_lock_prefixes_lock_name(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.lock("resource-lock", timeout=10)
|
||||
|
||||
mock_client.lock.assert_called_once()
|
||||
args, kwargs = mock_client.lock.call_args
|
||||
assert args == ("enterprise-a:resource-lock",)
|
||||
assert kwargs["timeout"] == 10
|
||||
|
||||
def test_wrapper_hash_operations_prefix_key_name(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.hset("hash:key", "field", "value")
|
||||
wrapper.hgetall("hash:key")
|
||||
|
||||
mock_client.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value")
|
||||
mock_client.hgetall.assert_called_once_with("enterprise-a:hash:key")
|
||||
|
||||
def test_wrapper_zadd_prefixes_sorted_set_name(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.zadd("zset:key", {"member": 1})
|
||||
|
||||
mock_client.zadd.assert_called_once()
|
||||
args, kwargs = mock_client.zadd.call_args
|
||||
assert args == ("enterprise-a:zset:key", {"member": 1})
|
||||
assert kwargs["nx"] is False
|
||||
|
||||
def test_wrapper_preserves_keys_when_prefix_is_empty(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = " "
|
||||
|
||||
wrapper.get("plain:key")
|
||||
|
||||
mock_client.get.assert_called_once_with("plain:key")
|
||||
|
||||
@ -139,6 +139,28 @@ class TestTopic:
|
||||
|
||||
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
|
||||
|
||||
def test_publish_prefixes_regular_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
topic = Topic(mock_redis_client, "test-topic")
|
||||
|
||||
topic.publish(b"test message")
|
||||
|
||||
mock_redis_client.publish.assert_called_once_with("enterprise-a:test-topic", b"test message")
|
||||
|
||||
def test_subscribe_prefixes_regular_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
topic = Topic(mock_redis_client, "test-topic")
|
||||
|
||||
subscription = topic.subscribe()
|
||||
try:
|
||||
subscription._start_if_needed()
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
mock_redis_client.pubsub.return_value.subscribe.assert_called_once_with("enterprise-a:test-topic")
|
||||
|
||||
|
||||
class TestShardedTopic:
|
||||
"""Test cases for the ShardedTopic class."""
|
||||
@ -176,6 +198,15 @@ class TestShardedTopic:
|
||||
|
||||
mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload)
|
||||
|
||||
def test_publish_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
|
||||
|
||||
sharded_topic.publish(b"test sharded message")
|
||||
|
||||
mock_redis_client.spublish.assert_called_once_with("enterprise-a:test-sharded-topic", b"test sharded message")
|
||||
|
||||
def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock):
|
||||
"""Test that subscribe() returns a _RedisShardedSubscription instance."""
|
||||
subscription = sharded_topic.subscribe()
|
||||
@ -185,6 +216,19 @@ class TestShardedTopic:
|
||||
assert subscription._pubsub is mock_redis_client.pubsub.return_value
|
||||
assert subscription._topic == "test-sharded-topic"
|
||||
|
||||
def test_subscribe_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
|
||||
|
||||
subscription = sharded_topic.subscribe()
|
||||
try:
|
||||
subscription._start_if_needed()
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
mock_redis_client.pubsub.return_value.ssubscribe.assert_called_once_with("enterprise-a:test-sharded-topic")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SubscriptionTestCase:
|
||||
|
||||
@ -2,6 +2,7 @@ import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -150,6 +151,25 @@ class TestStreamsBroadcastChannel:
|
||||
# Expire called after publish
|
||||
assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
|
||||
|
||||
def test_topic_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("alpha")
|
||||
|
||||
assert topic._topic == "alpha"
|
||||
assert topic._key == "enterprise-a:stream:alpha"
|
||||
|
||||
def test_publish_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("beta")
|
||||
|
||||
topic.publish(b"hello")
|
||||
|
||||
assert fake_redis._store["enterprise-a:stream:beta"][0][1] == {b"data": b"hello"}
|
||||
assert fake_redis._expire_calls.get("enterprise-a:stream:beta", 0) >= 1
|
||||
|
||||
def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel):
|
||||
topic = streams_channel.topic("producer-subscriber")
|
||||
|
||||
|
||||
@ -351,6 +351,9 @@ REDIS_SSL_CERTFILE=
|
||||
REDIS_SSL_KEYFILE=
|
||||
# Path to client private key file for SSL authentication
|
||||
REDIS_DB=0
|
||||
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
# Leave empty to preserve current unprefixed behavior.
|
||||
REDIS_KEY_PREFIX=
|
||||
# Optional: limit total Redis connections used by API/Worker (unset for default)
|
||||
# Align with API's REDIS_MAX_CONNECTIONS in configs
|
||||
REDIS_MAX_CONNECTIONS=
|
||||
|
||||
@ -88,6 +88,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
|
||||
1. **Redis Configuration**:
|
||||
|
||||
- `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings.
|
||||
- `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
|
||||
1. **Celery Configuration**:
|
||||
|
||||
|
||||
@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env
|
||||
REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
|
||||
REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
|
||||
REDIS_DB: ${REDIS_DB:-0}
|
||||
REDIS_KEY_PREFIX: ${REDIS_KEY_PREFIX:-}
|
||||
REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-}
|
||||
REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
|
||||
REDIS_SENTINELS: ${REDIS_SENTINELS:-}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user