feat: redis add retry logic (#34566)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei 2026-04-09 11:08:25 +08:00 committed by GitHub
parent 9308287fea
commit 27e484e7f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 217 additions and 46 deletions

View File

@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
REDIS_RETRY_RETRIES=3
REDIS_RETRY_BACKOFF_BASE=1.0
REDIS_RETRY_BACKOFF_CAP=10.0
REDIS_SOCKET_TIMEOUT=5.0
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
REDIS_HEALTH_CHECK_INTERVAL=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis

View File

@ -117,6 +117,37 @@ class RedisConfig(BaseSettings):
default=None,
)
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
description="Maximum number of retries per Redis command on "
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
default=3,
)
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
description="Base delay in seconds for exponential backoff between retries",
default=1.0,
)
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
description="Maximum backoff delay in seconds between retries",
default=10.0,
)
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis read/write operations",
default=5.0,
)
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis connection establishment",
default=5.0,
)
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
description="Interval in seconds between Redis connection health checks (0 to disable)",
default=30,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):

View File

@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
from redis.backoff import ExponentialWithJitterBackoff # type: ignore
from redis.cache import CacheConfig
from redis.client import PubSub
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from configs import dify_config
@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None:
return CacheConfig()
def _get_retry_policy() -> Retry:
    """Build the retry policy applied to Redis connections.

    The backoff base, backoff cap and retry count are read from
    ``dify_config`` (``REDIS_RETRY_BACKOFF_BASE``, ``REDIS_RETRY_BACKOFF_CAP``
    and ``REDIS_RETRY_RETRIES``) each time this function is called.

    Returns:
        A ``Retry`` configured with an ``ExponentialWithJitterBackoff``.
    """
    backoff_strategy = ExponentialWithJitterBackoff(
        base=dify_config.REDIS_RETRY_BACKOFF_BASE,
        cap=dify_config.REDIS_RETRY_BACKOFF_CAP,
    )
    return Retry(backoff=backoff_strategy, retries=dify_config.REDIS_RETRY_RETRIES)
def _get_connection_health_params() -> dict[str, Any]:
    """Collect retry, timeout and health-check keyword arguments.

    Used for standalone and Sentinel Redis clients.

    Returns:
        A new dict with the keys ``retry``, ``socket_timeout``,
        ``socket_connect_timeout`` and ``health_check_interval``; the retry
        policy comes from ``_get_retry_policy()`` and the remaining values
        from ``dify_config``.
    """
    return dict(
        retry=_get_retry_policy(),
        socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
        socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
        health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
    )
def _get_cluster_connection_health_params() -> dict[str, Any]:
    """Return retry and timeout keyword arguments for Redis Cluster clients.

    RedisCluster does not support ``health_check_interval`` as a constructor
    keyword (it is silently stripped by ``cleanup_kwargs``), so that key is
    dropped from the standalone/Sentinel parameter set. Only ``retry``,
    ``socket_timeout`` and ``socket_connect_timeout`` are passed through.

    Returns:
        A new dict holding every entry of ``_get_connection_health_params()``
        except ``health_check_interval``.
    """
    # Copy first so the dict produced by the helper is never mutated.
    cluster_kwargs = dict(_get_connection_health_params())
    cluster_kwargs.pop("health_check_interval", None)
    return cluster_kwargs
def _get_base_redis_params() -> dict[str, Any]:
"""Get base Redis connection parameters."""
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]:
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_connection_health_params(),
}
@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_cluster_connection_health_params(),
}
if dify_config.REDIS_MAX_CONNECTIONS:
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
redis_params.update(
params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
)
if dify_config.REDIS_MAX_CONNECTIONS:
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
if ssl_kwargs:
redis_params.update(ssl_kwargs)
params.update(ssl_kwargs)
pool = redis.ConnectionPool(**redis_params)
pool = redis.ConnectionPool(**params)
client: redis.Redis = redis.Redis(connection_pool=pool)
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
max_conns = dify_config.REDIS_MAX_CONNECTIONS
if use_clusters:
if max_conns:
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
else:
return RedisCluster.from_url(pubsub_url)
if use_clusters:
health_params = _get_cluster_connection_health_params()
kwargs: dict[str, Any] = {**health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params()
kwargs = {**health_params}
if max_conns:
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
else:
return redis.Redis.from_url(pubsub_url)
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)
def init_app(app: DifyApp):

View File

@ -1,53 +1,125 @@
from unittest.mock import patch
from redis import RedisError
from redis.retry import Retry
from extensions.ext_redis import redis_fallback
from extensions.ext_redis import (
_get_base_redis_params,
_get_cluster_connection_health_params,
_get_connection_health_params,
redis_fallback,
)
def test_redis_fallback_success():
@redis_fallback(default_return=None)
def test_func():
return "success"
class TestGetConnectionHealthParams:
@patch("extensions.ext_redis.dify_config")
def test_includes_all_health_params(self, mock_config):
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() == "success"
params = _get_connection_health_params()
assert "retry" in params
assert "socket_timeout" in params
assert "socket_connect_timeout" in params
assert "health_check_interval" in params
assert isinstance(params["retry"], Retry)
assert params["retry"]._retries == 3
assert params["socket_timeout"] == 5.0
assert params["socket_connect_timeout"] == 5.0
assert params["health_check_interval"] == 30
def test_redis_fallback_error():
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
class TestGetClusterConnectionHealthParams:
@patch("extensions.ext_redis.dify_config")
def test_excludes_health_check_interval(self, mock_config):
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() == "fallback"
params = _get_cluster_connection_health_params()
assert "retry" in params
assert "socket_timeout" in params
assert "socket_connect_timeout" in params
assert "health_check_interval" not in params
def test_redis_fallback_none_default():
@redis_fallback()
def test_func():
raise RedisError("Redis error")
class TestGetBaseRedisParams:
@patch("extensions.ext_redis.dify_config")
def test_includes_retry_and_health_params(self, mock_config):
mock_config.REDIS_USERNAME = None
mock_config.REDIS_PASSWORD = None
mock_config.REDIS_DB = 0
mock_config.REDIS_SERIALIZATION_PROTOCOL = 3
mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() is None
params = _get_base_redis_params()
assert "retry" in params
assert isinstance(params["retry"], Retry)
assert params["socket_timeout"] == 5.0
assert params["socket_connect_timeout"] == 5.0
assert params["health_check_interval"] == 30
# Existing params still present
assert params["db"] == 0
assert params["encoding"] == "utf-8"
def test_redis_fallback_with_args():
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
class TestRedisFallback:
def test_redis_fallback_success(self):
@redis_fallback(default_return=None)
def test_func():
return "success"
assert test_func(1, 2) == 0
assert test_func() == "success"
def test_redis_fallback_error(self):
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
def test_redis_fallback_with_kwargs():
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func() == "fallback"
assert test_func(x=1, y=2) == {}
def test_redis_fallback_none_default(self):
@redis_fallback()
def test_func():
raise RedisError("Redis error")
assert test_func() is None
def test_redis_fallback_preserves_function_metadata():
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
def test_redis_fallback_with_args(self):
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"
assert test_func(1, 2) == 0
def test_redis_fallback_with_kwargs(self):
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func(x=1, y=2) == {}
def test_redis_fallback_preserves_function_metadata(self):
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"

View File

@ -373,6 +373,20 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# Redis connection and retry configuration
# Maximum number of retries per Redis command on transient failures
REDIS_RETRY_RETRIES=3
# Base delay (in seconds) for exponential backoff on retries
REDIS_RETRY_BACKOFF_BASE=1.0
# Cap (in seconds) for exponential backoff on retries
REDIS_RETRY_BACKOFF_CAP=10.0
# Timeout (in seconds) for Redis socket operations
REDIS_SOCKET_TIMEOUT=5.0
# Timeout (in seconds) for establishing a Redis connection
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
# Interval (in seconds) between Redis connection health checks (0 to disable)
REDIS_HEALTH_CHECK_INTERVAL=30
# ------------------------------
# Celery Configuration
# ------------------------------

View File

@ -100,6 +100,12 @@ x-shared-env: &shared-api-worker-env
REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false}
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
REDIS_RETRY_RETRIES: ${REDIS_RETRY_RETRIES:-3}
REDIS_RETRY_BACKOFF_BASE: ${REDIS_RETRY_BACKOFF_BASE:-1.0}
REDIS_RETRY_BACKOFF_CAP: ${REDIS_RETRY_BACKOFF_CAP:-10.0}
REDIS_SOCKET_TIMEOUT: ${REDIS_SOCKET_TIMEOUT:-5.0}
REDIS_SOCKET_CONNECT_TIMEOUT: ${REDIS_SOCKET_CONNECT_TIMEOUT:-5.0}
REDIS_HEALTH_CHECK_INTERVAL: ${REDIS_HEALTH_CHECK_INTERVAL:-30}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
CELERY_BACKEND: ${CELERY_BACKEND:-redis}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}