mirror of https://github.com/langgenius/dify.git
perf(api): optimize tool provider list API with Redis caching (#29101)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
05fe92a541
commit
71497954b8
|
|
@ -0,0 +1,56 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolProviderListCache:
|
||||
"""Cache for tool provider lists"""
|
||||
|
||||
CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
|
||||
"""Generate cache key for tool providers list"""
|
||||
type_filter = typ or "all"
|
||||
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
|
||||
"""Get cached tool providers"""
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
cached_data = redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
try:
|
||||
return json.loads(cached_data.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
logger.warning("Failed to decode cached tool providers data")
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback()
|
||||
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
|
||||
"""Cache tool providers"""
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback()
|
||||
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
|
||||
"""Invalidate cache for tool providers"""
|
||||
if typ:
|
||||
# Invalidate specific type cache
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
redis_client.delete(cache_key)
|
||||
else:
|
||||
# Invalidate all caches for this tenant
|
||||
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
|
||||
keys = list(redis_client.scan_iter(pattern))
|
||||
if keys:
|
||||
redis_client.delete(*keys)
|
||||
|
|
@ -5,7 +5,7 @@ import time
|
|||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
|
|
@ -67,6 +67,11 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiProviderControllerItem(TypedDict):
|
||||
provider: ApiToolProvider
|
||||
controller: ApiToolProviderController
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
|
|
@ -655,9 +660,10 @@ class ToolManager:
|
|||
else:
|
||||
filters.append(typ)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
# Use a single session for all database operations to reduce connection overhead
|
||||
with Session(db.engine) as session:
|
||||
if "builtin" in filters:
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
builtin_providers = list(cls.list_builtin_providers(tenant_id))
|
||||
|
||||
# key: provider name, value: provider
|
||||
db_builtin_providers = {
|
||||
|
|
@ -688,57 +694,74 @@ class ToolManager:
|
|||
|
||||
# get db api providers
|
||||
if "api" in filters:
|
||||
db_api_providers = db.session.scalars(
|
||||
db_api_providers = session.scalars(
|
||||
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
# Batch create controllers
|
||||
api_provider_controllers: list[ApiProviderControllerItem] = []
|
||||
for api_provider in db_api_providers:
|
||||
try:
|
||||
controller = ToolTransformService.api_provider_to_controller(api_provider)
|
||||
api_provider_controllers.append({"provider": api_provider, "controller": controller})
|
||||
except Exception:
|
||||
# Skip invalid providers but continue processing others
|
||||
logger.warning("Failed to create controller for API provider %s", api_provider.id)
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
# Batch get labels for all API providers
|
||||
if api_provider_controllers:
|
||||
controllers = cast(
|
||||
list[ToolProviderController], [item["controller"] for item in api_provider_controllers]
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
labels = ToolLabelManager.get_tools_labels(controllers)
|
||||
|
||||
for item in api_provider_controllers:
|
||||
provider_controller = item["controller"]
|
||||
db_provider = item["provider"]
|
||||
provider_labels = labels.get(provider_controller.provider_id, [])
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=db_provider,
|
||||
decrypt_credentials=False,
|
||||
labels=provider_labels,
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers = db.session.scalars(
|
||||
workflow_providers = session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for workflow_provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
workflow_controller: WorkflowToolProviderController = (
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
)
|
||||
workflow_provider_controllers.append(workflow_controller)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id)
|
||||
continue
|
||||
# Batch get labels for workflow providers
|
||||
if workflow_provider_controllers:
|
||||
workflow_controllers: list[ToolProviderController] = [
|
||||
cast(ToolProviderController, controller) for controller in workflow_provider_controllers
|
||||
]
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_controllers)
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
for workflow_provider_controller in workflow_provider_controllers:
|
||||
provider_labels = labels.get(workflow_provider_controller.provider_id, [])
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=workflow_provider_controller,
|
||||
labels=provider_labels,
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from httpx import get
|
|||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
|
|
@ -177,6 +178,9 @@ class ApiToolManageService:
|
|||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -318,6 +322,9 @@ class ApiToolManageService:
|
|||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -340,6 +347,9 @@ class ApiToolManageService:
|
|||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
|||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
|
|
@ -204,6 +205,9 @@ class BuiltinToolManageService:
|
|||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
|
@ -282,6 +286,9 @@ class BuiltinToolManageService:
|
|||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
|
@ -402,6 +409,9 @@ class BuiltinToolManageService:
|
|||
)
|
||||
cache.delete()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -423,6 +433,9 @@ class BuiltinToolManageService:
|
|||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
|
|||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
|
|
@ -164,6 +165,10 @@ class MCPToolManageService:
|
|||
|
||||
self._session.add(mcp_tool)
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
|
|
@ -245,6 +250,9 @@ class MCPToolManageService:
|
|||
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except IntegrityError as e:
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
|
||||
|
|
@ -253,6 +261,9 @@ class MCPToolManageService:
|
|||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
def list_providers(
|
||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
|
@ -15,6 +16,14 @@ class ToolCommonService:
|
|||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
# Try to get from cache first
|
||||
cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
|
||||
if cached_result is not None:
|
||||
logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
|
||||
return cached_result
|
||||
|
||||
# Cache miss - fetch from database
|
||||
logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
|
||||
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
|
||||
|
||||
# add icon
|
||||
|
|
@ -23,4 +32,7 @@ class ToolCommonService:
|
|||
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
# Cache the result
|
||||
ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import Any
|
|||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
|
|
@ -91,6 +92,10 @@ class WorkflowToolManageService:
|
|||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -178,6 +183,9 @@ class WorkflowToolManageService:
|
|||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -240,6 +248,9 @@ class WorkflowToolManageService:
|
|||
|
||||
db.session.commit()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client():
|
||||
"""Fixture: Mock Redis client"""
|
||||
with patch("core.helper.tool_provider_cache.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
class TestToolProviderListCache:
|
||||
"""Test class for ToolProviderListCache"""
|
||||
|
||||
def test_generate_cache_key(self):
|
||||
"""Test cache key generation logic"""
|
||||
# Scenario 1: Specify typ (valid literal value)
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "builtin"
|
||||
expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}"
|
||||
assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key
|
||||
|
||||
# Scenario 2: typ is None (defaults to "all")
|
||||
expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all"
|
||||
assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all
|
||||
|
||||
def test_get_cached_providers_hit(self, mock_redis_client):
|
||||
"""Test get cached providers - cache hit and successful decoding"""
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "api"
|
||||
mock_providers = [{"id": "tool", "name": "test_provider"}]
|
||||
mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8")
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
|
||||
|
||||
mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ))
|
||||
assert result == mock_providers
|
||||
|
||||
def test_get_cached_providers_decode_error(self, mock_redis_client):
|
||||
"""Test get cached providers - cache hit but decoding failed"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_redis_client.get.return_value = b"invalid_json_data"
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers(tenant_id)
|
||||
|
||||
assert result is None
|
||||
mock_redis_client.get.assert_called_once()
|
||||
|
||||
def test_get_cached_providers_miss(self, mock_redis_client):
|
||||
"""Test get cached providers - cache miss"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers(tenant_id)
|
||||
|
||||
assert result is None
|
||||
mock_redis_client.get.assert_called_once()
|
||||
|
||||
def test_set_cached_providers(self, mock_redis_client):
|
||||
"""Test set cached providers"""
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "builtin"
|
||||
mock_providers = [{"id": "tool", "name": "test_provider"}]
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
|
||||
ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)
|
||||
|
||||
mock_redis_client.setex.assert_called_once_with(
|
||||
cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
|
||||
)
|
||||
|
||||
def test_invalidate_cache_specific_type(self, mock_redis_client):
|
||||
"""Test invalidate cache - specific type"""
|
||||
tenant_id = "tenant_123"
|
||||
typ: ToolProviderTypeApiLiteral = "workflow"
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id, typ)
|
||||
|
||||
mock_redis_client.delete.assert_called_once_with(cache_key)
|
||||
|
||||
def test_invalidate_cache_all_types(self, mock_redis_client):
|
||||
"""Test invalidate cache - clear all tenant cache"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_keys = [
|
||||
b"tool_providers:tenant_id:tenant_123:type:all",
|
||||
b"tool_providers:tenant_id:tenant_123:type:builtin",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = mock_keys
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
def test_invalidate_cache_no_keys(self, mock_redis_client):
|
||||
"""Test invalidate cache - no cache keys for tenant"""
|
||||
tenant_id = "tenant_123"
|
||||
mock_redis_client.scan_iter.return_value = []
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
def test_redis_fallback_default_return(self, mock_redis_client):
|
||||
"""Test redis_fallback decorator - default return value (Redis error)"""
|
||||
mock_redis_client.get.side_effect = RedisError("Redis connection error")
|
||||
|
||||
result = ToolProviderListCache.get_cached_providers("tenant_123")
|
||||
|
||||
assert result is None
|
||||
mock_redis_client.get.assert_called_once()
|
||||
|
||||
def test_redis_fallback_no_default(self, mock_redis_client):
|
||||
"""Test redis_fallback decorator - no default return value (Redis error)"""
|
||||
mock_redis_client.setex.side_effect = RedisError("Redis connection error")
|
||||
|
||||
try:
|
||||
ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
|
||||
except RedisError:
|
||||
pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")
|
||||
|
||||
mock_redis_client.setex.assert_called_once()
|
||||
Loading…
Reference in New Issue