From 27e484e7f83803d68a3ad47f69ab2925dd0b095b Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 9 Apr 2026 11:08:25 +0800 Subject: [PATCH] feat: redis add retry logic (#34566) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/.env.example | 7 + api/configs/middleware/cache/redis_config.py | 31 ++++ api/extensions/ext_redis.py | 67 +++++++-- api/tests/unit_tests/extensions/test_redis.py | 138 +++++++++++++----- docker/.env.example | 14 ++ docker/docker-compose.yaml | 6 + 6 files changed, 217 insertions(+), 46 deletions(-) diff --git a/api/.env.example b/api/.env.example index c6541731e6..2c1a755059 100644 --- a/api/.env.example +++ b/api/.env.example @@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +REDIS_RETRY_RETRIES=3 +REDIS_RETRY_BACKOFF_BASE=1.0 +REDIS_RETRY_BACKOFF_CAP=10.0 +REDIS_SOCKET_TIMEOUT=5.0 +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +REDIS_HEALTH_CHECK_INTERVAL=30 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BACKEND=redis diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 3b91207545..b49275758a 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -117,6 +117,37 @@ class RedisConfig(BaseSettings): default=None, ) + REDIS_RETRY_RETRIES: NonNegativeInt = Field( + description="Maximum number of retries per Redis command on " + "transient failures (ConnectionError, TimeoutError, socket.timeout)", + default=3, + ) + + REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field( + description="Base delay in seconds for exponential backoff between retries", + default=1.0, + ) + + REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field( + description="Maximum backoff delay in seconds between retries", + default=10.0, + ) + + REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis read/write operations", + default=5.0, + ) + + REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis connection establishment", + default=5.0, + ) + + REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field( + description="Interval in seconds between Redis connection health checks (0 to disable)", + default=30, + ) + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") @classmethod def _empty_string_to_none_for_max_conns(cls, v): diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 5f528dbf9e..b9e592cadb 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union import redis from redis import RedisError +from redis.backoff import ExponentialWithJitterBackoff # type: ignore from redis.cache import CacheConfig from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection +from redis.retry import Retry from redis.sentinel import Sentinel from configs import dify_config @@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None: return CacheConfig() +def _get_retry_policy() -> Retry: + """Build the shared retry policy for Redis connections.""" + return Retry( + backoff=ExponentialWithJitterBackoff( + base=dify_config.REDIS_RETRY_BACKOFF_BASE, + cap=dify_config.REDIS_RETRY_BACKOFF_CAP, + ), + retries=dify_config.REDIS_RETRY_RETRIES, + ) + + +def _get_connection_health_params() -> dict[str, Any]: + """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" + return { + "retry": _get_retry_policy(), + "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, + "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, + } + + +def _get_cluster_connection_health_params() -> dict[str, Any]: + """Get retry and timeout parameters for Redis Cluster clients. + + RedisCluster does not support ``health_check_interval`` as a constructor + keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded + here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` + are passed through. + """ + params = _get_connection_health_params() + return {k: v for k, v in params.items() if k != "health_check_interval"} + + def _get_base_redis_params() -> dict[str, Any]: - """Get base Redis connection parameters.""" + """Get base Redis connection parameters including retry and health policy.""" return { "username": dify_config.REDIS_USERNAME, "password": dify_config.REDIS_PASSWORD or None, @@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]: "decode_responses": False, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_connection_health_params(), } @@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: "password": dify_config.REDIS_CLUSTERS_PASSWORD, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_cluster_connection_health_params(), } if dify_config.REDIS_MAX_CONNECTIONS: cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - redis_params.update( + params = {**redis_params} + params.update( { "host": dify_config.REDIS_HOST, "port": dify_config.REDIS_PORT, @@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis ) if dify_config.REDIS_MAX_CONNECTIONS: - redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS if ssl_kwargs: - redis_params.update(ssl_kwargs) + params.update(ssl_kwargs) - pool = redis.ConnectionPool(**redis_params) + pool = redis.ConnectionPool(**params) client: redis.Redis = redis.Redis(connection_pool=pool) return client def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: max_conns = dify_config.REDIS_MAX_CONNECTIONS - if use_clusters: - if max_conns: - return RedisCluster.from_url(pubsub_url, max_connections=max_conns) - else: - return RedisCluster.from_url(pubsub_url) + if use_clusters: + health_params = _get_cluster_connection_health_params() + kwargs: dict[str, Any] = {**health_params} + if max_conns: + kwargs["max_connections"] = max_conns + return RedisCluster.from_url(pubsub_url, **kwargs) + + health_params = _get_connection_health_params() + kwargs = {**health_params} if max_conns: - return redis.Redis.from_url(pubsub_url, max_connections=max_conns) - else: - return redis.Redis.from_url(pubsub_url) + kwargs["max_connections"] = max_conns + return redis.Redis.from_url(pubsub_url, **kwargs) def init_app(app: DifyApp): diff --git a/api/tests/unit_tests/extensions/test_redis.py b/api/tests/unit_tests/extensions/test_redis.py index 933fa32894..5e9be4ab9b 100644 --- a/api/tests/unit_tests/extensions/test_redis.py +++ b/api/tests/unit_tests/extensions/test_redis.py @@ -1,53 +1,125 @@ +from unittest.mock import patch + from redis import RedisError +from redis.retry import Retry -from extensions.ext_redis import redis_fallback +from extensions.ext_redis import ( + _get_base_redis_params, + _get_cluster_connection_health_params, + _get_connection_health_params, + redis_fallback, +) -def test_redis_fallback_success(): - @redis_fallback(default_return=None) - def test_func(): - return "success" +class TestGetConnectionHealthParams: + @patch("extensions.ext_redis.dify_config") + def test_includes_all_health_params(self, mock_config): + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() == "success" + params = _get_connection_health_params() + + assert "retry" in params + assert "socket_timeout" in params + assert "socket_connect_timeout" in params + assert "health_check_interval" in params + assert isinstance(params["retry"], Retry) + assert params["retry"]._retries == 3 + assert params["socket_timeout"] == 5.0 + assert params["socket_connect_timeout"] == 5.0 + assert params["health_check_interval"] == 30 -def test_redis_fallback_error(): - @redis_fallback(default_return="fallback") - def test_func(): - raise RedisError("Redis error") +class TestGetClusterConnectionHealthParams: + @patch("extensions.ext_redis.dify_config") + def test_excludes_health_check_interval(self, mock_config): + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() == "fallback" + params = _get_cluster_connection_health_params() + + assert "retry" in params + assert "socket_timeout" in params + assert "socket_connect_timeout" in params + assert "health_check_interval" not in params -def test_redis_fallback_none_default(): - @redis_fallback() - def test_func(): - raise RedisError("Redis error") +class TestGetBaseRedisParams: + @patch("extensions.ext_redis.dify_config") + def test_includes_retry_and_health_params(self, mock_config): + mock_config.REDIS_USERNAME = None + mock_config.REDIS_PASSWORD = None + mock_config.REDIS_DB = 0 + mock_config.REDIS_SERIALIZATION_PROTOCOL = 3 + mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() is None + params = _get_base_redis_params() + + assert "retry" in params + assert isinstance(params["retry"], Retry) + assert params["socket_timeout"] == 5.0 + assert params["socket_connect_timeout"] == 5.0 + assert params["health_check_interval"] == 30 + # Existing params still present + assert params["db"] == 0 + assert params["encoding"] == "utf-8" -def test_redis_fallback_with_args(): - @redis_fallback(default_return=0) - def test_func(x, y): - raise RedisError("Redis error") +class TestRedisFallback: + def test_redis_fallback_success(self): + @redis_fallback(default_return=None) + def test_func(): + return "success" - assert test_func(1, 2) == 0 + assert test_func() == "success" + def test_redis_fallback_error(self): + @redis_fallback(default_return="fallback") + def test_func(): + raise RedisError("Redis error") -def test_redis_fallback_with_kwargs(): - @redis_fallback(default_return={}) - def test_func(x=None, y=None): - raise RedisError("Redis error") + assert test_func() == "fallback" - assert test_func(x=1, y=2) == {} + def test_redis_fallback_none_default(self): + @redis_fallback() + def test_func(): + raise RedisError("Redis error") + assert test_func() is None -def test_redis_fallback_preserves_function_metadata(): - @redis_fallback(default_return=None) - def test_func(): - """Test function docstring""" - pass + def test_redis_fallback_with_args(self): + @redis_fallback(default_return=0) + def test_func(x, y): + raise RedisError("Redis error") - assert test_func.__name__ == "test_func" - assert test_func.__doc__ == "Test function docstring" + assert test_func(1, 2) == 0 + + def test_redis_fallback_with_kwargs(self): + @redis_fallback(default_return={}) + def test_func(x=None, y=None): + raise RedisError("Redis error") + + assert test_func(x=1, y=2) == {} + + def test_redis_fallback_preserves_function_metadata(self): + @redis_fallback(default_return=None) + def test_func(): + """Test function docstring""" + pass + + assert test_func.__name__ == "test_func" + assert test_func.__doc__ == "Test function docstring" diff --git a/docker/.env.example b/docker/.env.example index c046f6d378..f6da6c568d 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -373,6 +373,20 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +# Redis connection and retry configuration +# max redis retry +REDIS_RETRY_RETRIES=3 +# Base delay (in seconds) for exponential backoff on retries +REDIS_RETRY_BACKOFF_BASE=1.0 +# Cap (in seconds) for exponential backoff on retries +REDIS_RETRY_BACKOFF_CAP=10.0 +# Timeout (in seconds) for Redis socket operations +REDIS_SOCKET_TIMEOUT=5.0 +# Timeout (in seconds) for establishing a Redis connection +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +# Interval (in seconds) for Redis health checks +REDIS_HEALTH_CHECK_INTERVAL=30 + # ------------------------------ # Celery Configuration # ------------------------------ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3f6a13e78e..dbadc58f89 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -100,6 +100,12 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} + REDIS_RETRY_RETRIES: ${REDIS_RETRY_RETRIES:-3} + REDIS_RETRY_BACKOFF_BASE: ${REDIS_RETRY_BACKOFF_BASE:-1.0} + REDIS_RETRY_BACKOFF_CAP: ${REDIS_RETRY_BACKOFF_CAP:-10.0} + REDIS_SOCKET_TIMEOUT: ${REDIS_SOCKET_TIMEOUT:-5.0} + REDIS_SOCKET_CONNECT_TIMEOUT: ${REDIS_SOCKET_CONNECT_TIMEOUT:-5.0} + REDIS_HEALTH_CHECK_INTERVAL: ${REDIS_HEALTH_CHECK_INTERVAL:-30} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} CELERY_BACKEND: ${CELERY_BACKEND:-redis} BROKER_USE_SSL: ${BROKER_USE_SSL:-false}