perf(api): optimize tool provider list API with Redis caching (#29101)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
yangzheli 2025-12-08 14:34:03 +08:00 committed by GitHub
parent 05fe92a541
commit 71497954b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 297 additions and 32 deletions

View File

@ -0,0 +1,56 @@
import json
import logging
from typing import Any
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
import sqlalchemy as sa
from sqlalchemy import select
@ -67,6 +67,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ApiProviderControllerItem(TypedDict):
provider: ApiToolProvider
controller: ApiToolProviderController
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -655,9 +660,10 @@ class ToolManager:
else:
filters.append(typ)
with db.session.no_autoflush:
# Use a single session for all database operations to reduce connection overhead
with Session(db.engine) as session:
if "builtin" in filters:
builtin_providers = cls.list_builtin_providers(tenant_id)
builtin_providers = list(cls.list_builtin_providers(tenant_id))
# key: provider name, value: provider
db_builtin_providers = {
@ -688,57 +694,74 @@ class ToolManager:
# get db api providers
if "api" in filters:
db_api_providers = db.session.scalars(
db_api_providers = session.scalars(
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
).all()
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
# Batch create controllers
api_provider_controllers: list[ApiProviderControllerItem] = []
for api_provider in db_api_providers:
try:
controller = ToolTransformService.api_provider_to_controller(api_provider)
api_provider_controllers.append({"provider": api_provider, "controller": controller})
except Exception:
# Skip invalid providers but continue processing others
logger.warning("Failed to create controller for API provider %s", api_provider.id)
# get labels
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
for api_provider_controller in api_provider_controllers:
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=api_provider_controller["controller"],
db_provider=api_provider_controller["provider"],
decrypt_credentials=False,
labels=labels.get(api_provider_controller["controller"].provider_id, []),
# Batch get labels for all API providers
if api_provider_controllers:
controllers = cast(
list[ToolProviderController], [item["controller"] for item in api_provider_controllers]
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
labels = ToolLabelManager.get_tools_labels(controllers)
for item in api_provider_controllers:
provider_controller = item["controller"]
db_provider = item["provider"]
provider_labels = labels.get(provider_controller.provider_id, [])
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=db_provider,
decrypt_credentials=False,
labels=provider_labels,
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
if "workflow" in filters:
# get workflow providers
workflow_providers = db.session.scalars(
workflow_providers = session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for workflow_provider in workflow_providers:
try:
workflow_provider_controllers.append(
workflow_controller: WorkflowToolProviderController = (
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
)
workflow_provider_controllers.append(workflow_controller)
except Exception:
# app has been deleted
logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id)
continue
# Batch get labels for workflow providers
if workflow_provider_controllers:
workflow_controllers: list[ToolProviderController] = [
cast(ToolProviderController, controller) for controller in workflow_provider_controllers
]
labels = ToolLabelManager.get_tools_labels(workflow_controllers)
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for workflow_provider_controller in workflow_provider_controllers:
provider_labels = labels.get(workflow_provider_controller.provider_id, [])
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=workflow_provider_controller,
labels=provider_labels,
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller,
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider

View File

@ -7,6 +7,7 @@ from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
@ -177,6 +178,9 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -318,6 +322,9 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -340,6 +347,9 @@ class ApiToolManageService:
db.session.delete(provider)
db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod

View File

@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -204,6 +205,9 @@ class BuiltinToolManageService:
db_provider.name = name
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
@ -282,6 +286,9 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
@ -402,6 +409,9 @@ class BuiltinToolManageService:
)
cache.delete()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -423,6 +433,9 @@ class BuiltinToolManageService:
# set new default provider
target_provider.is_default = True
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod

View File

@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
@ -164,6 +165,10 @@ class MCPToolManageService:
self._session.add(mcp_tool)
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
@ -245,6 +250,9 @@ class MCPToolManageService:
# Flush changes to database
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)
@ -253,6 +261,9 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:

View File

@ -1,5 +1,6 @@
import logging
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@ -15,6 +16,14 @@ class ToolCommonService:
:return: the list of tool providers
"""
# Try to get from cache first
cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
if cached_result is not None:
logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
return cached_result
# Cache miss - fetch from database
logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon
@ -23,4 +32,7 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
# Cache the result
ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
return result

View File

@ -7,6 +7,7 @@ from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -91,6 +92,10 @@ class WorkflowToolManageService:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod
@ -178,6 +183,9 @@ class WorkflowToolManageService:
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod
@ -240,6 +248,9 @@ class WorkflowToolManageService:
db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod

View File

@ -0,0 +1,129 @@
import json
from unittest.mock import patch
import pytest
from redis.exceptions import RedisError
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
@pytest.fixture
def mock_redis_client():
"""Fixture: Mock Redis client"""
with patch("core.helper.tool_provider_cache.redis_client") as mock:
yield mock
class TestToolProviderListCache:
"""Test class for ToolProviderListCache"""
def test_generate_cache_key(self):
"""Test cache key generation logic"""
# Scenario 1: Specify typ (valid literal value)
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "builtin"
expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}"
assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key
# Scenario 2: typ is None (defaults to "all")
expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all"
assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all
def test_get_cached_providers_hit(self, mock_redis_client):
"""Test get cached providers - cache hit and successful decoding"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "api"
mock_providers = [{"id": "tool", "name": "test_provider"}]
mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8")
result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ))
assert result == mock_providers
def test_get_cached_providers_decode_error(self, mock_redis_client):
"""Test get cached providers - cache hit but decoding failed"""
tenant_id = "tenant_123"
mock_redis_client.get.return_value = b"invalid_json_data"
result = ToolProviderListCache.get_cached_providers(tenant_id)
assert result is None
mock_redis_client.get.assert_called_once()
def test_get_cached_providers_miss(self, mock_redis_client):
"""Test get cached providers - cache miss"""
tenant_id = "tenant_123"
mock_redis_client.get.return_value = None
result = ToolProviderListCache.get_cached_providers(tenant_id)
assert result is None
mock_redis_client.get.assert_called_once()
def test_set_cached_providers(self, mock_redis_client):
"""Test set cached providers"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "builtin"
mock_providers = [{"id": "tool", "name": "test_provider"}]
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)
mock_redis_client.setex.assert_called_once_with(
cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
)
def test_invalidate_cache_specific_type(self, mock_redis_client):
"""Test invalidate cache - specific type"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "workflow"
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
ToolProviderListCache.invalidate_cache(tenant_id, typ)
mock_redis_client.delete.assert_called_once_with(cache_key)
def test_invalidate_cache_all_types(self, mock_redis_client):
"""Test invalidate cache - clear all tenant cache"""
tenant_id = "tenant_123"
mock_keys = [
b"tool_providers:tenant_id:tenant_123:type:all",
b"tool_providers:tenant_id:tenant_123:type:builtin",
]
mock_redis_client.scan_iter.return_value = mock_keys
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
mock_redis_client.delete.assert_called_once_with(*mock_keys)
def test_invalidate_cache_no_keys(self, mock_redis_client):
"""Test invalidate cache - no cache keys for tenant"""
tenant_id = "tenant_123"
mock_redis_client.scan_iter.return_value = []
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.delete.assert_not_called()
def test_redis_fallback_default_return(self, mock_redis_client):
"""Test redis_fallback decorator - default return value (Redis error)"""
mock_redis_client.get.side_effect = RedisError("Redis connection error")
result = ToolProviderListCache.get_cached_providers("tenant_123")
assert result is None
mock_redis_client.get.assert_called_once()
def test_redis_fallback_no_default(self, mock_redis_client):
"""Test redis_fallback decorator - no default return value (Redis error)"""
mock_redis_client.setex.side_effect = RedisError("Redis connection error")
try:
ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
except RedisError:
pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")
mock_redis_client.setex.assert_called_once()