# dify/api/tests/unit_tests/extensions/test_redis.py
#
# 225 lines
# 7.8 KiB
# Python

from unittest.mock import MagicMock, patch
from redis import RedisError
from redis.retry import Retry
from extensions.ext_redis import (
RedisClientWrapper,
_get_base_redis_params,
_get_cluster_connection_health_params,
_get_connection_health_params,
_normalize_redis_key_prefix,
_serialize_redis_name,
redis_fallback,
)
class TestGetConnectionHealthParams:
    """Unit tests for ``_get_connection_health_params``."""

    @patch("extensions.ext_redis.dify_config")
    def test_includes_all_health_params(self, mock_config):
        """The result exposes retry, both socket timeouts and the health-check interval.

        ``extensions.ext_redis.dify_config`` is replaced by a mock with pinned
        Redis settings so the expected values below are deterministic.
        """
        mock_config.configure_mock(
            REDIS_RETRY_RETRIES=3,
            REDIS_RETRY_BACKOFF_BASE=1.0,
            REDIS_RETRY_BACKOFF_CAP=10.0,
            REDIS_SOCKET_TIMEOUT=5.0,
            REDIS_SOCKET_CONNECT_TIMEOUT=5.0,
            REDIS_HEALTH_CHECK_INTERVAL=30,
        )

        health_params = _get_connection_health_params()

        # Presence first, so a missing key fails with a clear message
        # rather than a KeyError in the value checks further down.
        for expected_key in (
            "retry",
            "socket_timeout",
            "socket_connect_timeout",
            "health_check_interval",
        ):
            assert expected_key in health_params

        retry_policy = health_params["retry"]
        assert isinstance(retry_policy, Retry)
        assert retry_policy._retries == 3
        assert health_params["socket_timeout"] == 5.0
        assert health_params["socket_connect_timeout"] == 5.0
        assert health_params["health_check_interval"] == 30
class TestGetClusterConnectionHealthParams:
    """Unit tests for ``_get_cluster_connection_health_params``."""

    @patch("extensions.ext_redis.dify_config")
    def test_excludes_health_check_interval(self, mock_config):
        """Retry and socket timeouts are present; ``health_check_interval`` is not.

        The interval is still set on the mocked config, so its absence from the
        result is checked against a configured value rather than a missing one.
        """
        mock_config.configure_mock(
            REDIS_RETRY_RETRIES=3,
            REDIS_RETRY_BACKOFF_BASE=1.0,
            REDIS_RETRY_BACKOFF_CAP=10.0,
            REDIS_SOCKET_TIMEOUT=5.0,
            REDIS_SOCKET_CONNECT_TIMEOUT=5.0,
            REDIS_HEALTH_CHECK_INTERVAL=30,
        )

        cluster_params = _get_cluster_connection_health_params()

        for expected_key in ("retry", "socket_timeout", "socket_connect_timeout"):
            assert expected_key in cluster_params
        assert "health_check_interval" not in cluster_params
class TestGetBaseRedisParams:
    """Unit tests for ``_get_base_redis_params``."""

    @patch("extensions.ext_redis.dify_config")
    def test_includes_retry_and_health_params(self, mock_config):
        """Base params contain the health settings alongside ``db`` and ``encoding``."""
        # Connection identity / protocol settings.
        mock_config.configure_mock(
            REDIS_USERNAME=None,
            REDIS_PASSWORD=None,
            REDIS_DB=0,
            REDIS_SERIALIZATION_PROTOCOL=3,
            REDIS_ENABLE_CLIENT_SIDE_CACHE=False,
        )
        # Retry / timeout / health-check settings.
        mock_config.configure_mock(
            REDIS_RETRY_RETRIES=3,
            REDIS_RETRY_BACKOFF_BASE=1.0,
            REDIS_RETRY_BACKOFF_CAP=10.0,
            REDIS_SOCKET_TIMEOUT=5.0,
            REDIS_SOCKET_CONNECT_TIMEOUT=5.0,
            REDIS_HEALTH_CHECK_INTERVAL=30,
        )

        base_params = _get_base_redis_params()

        assert "retry" in base_params
        assert isinstance(base_params["retry"], Retry)
        assert base_params["socket_timeout"] == 5.0
        assert base_params["socket_connect_timeout"] == 5.0
        assert base_params["health_check_interval"] == 30

        # Existing params still present
        assert base_params["db"] == 0
        assert base_params["encoding"] == "utf-8"
class TestRedisFallback:
    """Unit tests for the ``redis_fallback`` decorator."""

    def test_redis_fallback_success(self):
        """A decorated function that does not raise returns its own result."""

        @redis_fallback(default_return=None)
        def succeeds():
            return "success"

        outcome = succeeds()
        assert outcome == "success"

    def test_redis_fallback_error(self):
        """A ``RedisError`` raised inside the function yields ``default_return``."""

        @redis_fallback(default_return="fallback")
        def fails():
            raise RedisError("Redis error")

        outcome = fails()
        assert outcome == "fallback"

    def test_redis_fallback_none_default(self):
        """Without an explicit ``default_return`` the fallback value is ``None``."""

        @redis_fallback()
        def fails():
            raise RedisError("Redis error")

        outcome = fails()
        assert outcome is None

    def test_redis_fallback_with_args(self):
        """The fallback is returned when the function is called positionally."""

        @redis_fallback(default_return=0)
        def fails(x, y):
            raise RedisError("Redis error")

        outcome = fails(1, 2)
        assert outcome == 0

    def test_redis_fallback_with_kwargs(self):
        """The fallback is returned when the function is called with keywords."""

        @redis_fallback(default_return={})
        def fails(x=None, y=None):
            raise RedisError("Redis error")

        outcome = fails(x=1, y=2)
        assert outcome == {}

    def test_redis_fallback_preserves_function_metadata(self):
        """``__name__`` and ``__doc__`` of the decorated function are kept."""

        # The function name and docstring text are what the assertions compare
        # against, so they must stay exactly as written.
        @redis_fallback(default_return=None)
        def test_func():
            """Test function docstring"""

        assert test_func.__name__ == "test_func"
        assert test_func.__doc__ == "Test function docstring"
class TestRedisKeyPrefixHelpers:
    """Unit tests for ``_normalize_redis_key_prefix`` and ``_serialize_redis_name``."""

    def test_normalize_redis_key_prefix_trims_whitespace(self):
        """Whitespace around the prefix is removed."""
        normalized = _normalize_redis_key_prefix(" enterprise-a ")
        assert normalized == "enterprise-a"

    def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self):
        """A prefix made only of whitespace normalizes to the empty string."""
        normalized = _normalize_redis_key_prefix(" ")
        assert normalized == ""

    def test_serialize_redis_name_returns_original_when_prefix_empty(self):
        """With an empty prefix the name comes back unchanged."""
        name = "model_lb_index:test"
        serialized = _serialize_redis_name(name, "")
        assert serialized == "model_lb_index:test"

    def test_serialize_redis_name_adds_single_colon_separator(self):
        """Prefix and name are joined with exactly one ``:``."""
        name = "model_lb_index:test"
        serialized = _serialize_redis_name(name, "enterprise-a")
        assert serialized == "enterprise-a:model_lb_index:test"
class TestRedisClientWrapperKeyPrefix:
    """Unit tests checking the key names ``RedisClientWrapper`` forwards to its client.

    Each test patches ``extensions.ext_redis.dify_config`` and sets
    ``REDIS_KEY_PREFIX`` on the mock before calling the wrapper.
    """

    @staticmethod
    def _wrapper_with_backend():
        """Return an initialized ``RedisClientWrapper`` and the ``MagicMock`` client behind it."""
        backend = MagicMock()
        wrapper = RedisClientWrapper()
        wrapper.initialize(backend)
        return wrapper, backend

    def test_wrapper_get_prefixes_string_keys(self):
        """``get`` reaches the client with the prefix and a ``:`` prepended to the key."""
        wrapper, backend = self._wrapper_with_backend()

        with patch("extensions.ext_redis.dify_config") as patched_config:
            patched_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.get("oauth_state:abc")

        backend.get.assert_called_once_with("enterprise-a:oauth_state:abc")

    def test_wrapper_delete_prefixes_multiple_keys(self):
        """``delete`` prefixes every key it is given, in a single client call."""
        wrapper, backend = self._wrapper_with_backend()

        with patch("extensions.ext_redis.dify_config") as patched_config:
            patched_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.delete("key:a", "key:b")

        backend.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b")

    def test_wrapper_lock_prefixes_lock_name(self):
        """``lock`` prefixes the lock name and passes ``timeout`` through as a keyword."""
        wrapper, backend = self._wrapper_with_backend()

        with patch("extensions.ext_redis.dify_config") as patched_config:
            patched_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.lock("resource-lock", timeout=10)

        backend.lock.assert_called_once()
        positional, keyword = backend.lock.call_args
        assert positional == ("enterprise-a:resource-lock",)
        assert keyword["timeout"] == 10

    def test_wrapper_hash_operations_prefix_key_name(self):
        """``hset`` and ``hgetall`` prefix the hash key; field and value are untouched."""
        wrapper, backend = self._wrapper_with_backend()

        with patch("extensions.ext_redis.dify_config") as patched_config:
            patched_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.hset("hash:key", "field", "value")
            wrapper.hgetall("hash:key")

        backend.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value")
        backend.hgetall.assert_called_once_with("enterprise-a:hash:key")

    def test_wrapper_zadd_prefixes_sorted_set_name(self):
        """``zadd`` prefixes the sorted-set name and forwards ``nx`` as ``False``."""
        wrapper, backend = self._wrapper_with_backend()

        with patch("extensions.ext_redis.dify_config") as patched_config:
            patched_config.REDIS_KEY_PREFIX = "enterprise-a"
            wrapper.zadd("zset:key", {"member": 1})

        backend.zadd.assert_called_once()
        positional, keyword = backend.zadd.call_args
        assert positional == ("enterprise-a:zset:key", {"member": 1})
        assert keyword["nx"] is False

    def test_wrapper_preserves_keys_when_prefix_is_empty(self):
        """With a whitespace-only prefix the key reaches the client unchanged."""
        wrapper, backend = self._wrapper_with_backend()

        with patch("extensions.ext_redis.dify_config") as patched_config:
            patched_config.REDIS_KEY_PREFIX = " "
            wrapper.get("plain:key")

        backend.get.assert_called_once_with("plain:key")