perf(api): optimize tool provider list API with Redis caching (#29101)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
yangzheli 2025-12-08 14:34:03 +08:00 committed by GitHub
parent 05fe92a541
commit 71497954b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 297 additions and 32 deletions

View File

@ -0,0 +1,56 @@
import json
import logging
from typing import Any
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import select from sqlalchemy import select
@ -67,6 +67,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ApiProviderControllerItem(TypedDict):
provider: ApiToolProvider
controller: ApiToolProviderController
class ToolManager: class ToolManager:
_builtin_provider_lock = Lock() _builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {} _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -655,9 +660,10 @@ class ToolManager:
else: else:
filters.append(typ) filters.append(typ)
with db.session.no_autoflush: # Use a single session for all database operations to reduce connection overhead
with Session(db.engine) as session:
if "builtin" in filters: if "builtin" in filters:
builtin_providers = cls.list_builtin_providers(tenant_id) builtin_providers = list(cls.list_builtin_providers(tenant_id))
# key: provider name, value: provider # key: provider name, value: provider
db_builtin_providers = { db_builtin_providers = {
@ -688,57 +694,74 @@ class ToolManager:
# get db api providers # get db api providers
if "api" in filters: if "api" in filters:
db_api_providers = db.session.scalars( db_api_providers = session.scalars(
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
).all() ).all()
api_provider_controllers: list[dict[str, Any]] = [ # Batch create controllers
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} api_provider_controllers: list[ApiProviderControllerItem] = []
for provider in db_api_providers for api_provider in db_api_providers:
] try:
controller = ToolTransformService.api_provider_to_controller(api_provider)
api_provider_controllers.append({"provider": api_provider, "controller": controller})
except Exception:
# Skip invalid providers but continue processing others
logger.warning("Failed to create controller for API provider %s", api_provider.id)
# get labels # Batch get labels for all API providers
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) if api_provider_controllers:
controllers = cast(
for api_provider_controller in api_provider_controllers: list[ToolProviderController], [item["controller"] for item in api_provider_controllers]
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=api_provider_controller["controller"],
db_provider=api_provider_controller["provider"],
decrypt_credentials=False,
labels=labels.get(api_provider_controller["controller"].provider_id, []),
) )
result_providers[f"api_provider.{user_provider.name}"] = user_provider labels = ToolLabelManager.get_tools_labels(controllers)
for item in api_provider_controllers:
provider_controller = item["controller"]
db_provider = item["provider"]
provider_labels = labels.get(provider_controller.provider_id, [])
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=db_provider,
decrypt_credentials=False,
labels=provider_labels,
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
if "workflow" in filters: if "workflow" in filters:
# get workflow providers # get workflow providers
workflow_providers = db.session.scalars( workflow_providers = session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all() ).all()
workflow_provider_controllers: list[WorkflowToolProviderController] = [] workflow_provider_controllers: list[WorkflowToolProviderController] = []
for workflow_provider in workflow_providers: for workflow_provider in workflow_providers:
try: try:
workflow_provider_controllers.append( workflow_controller: WorkflowToolProviderController = (
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
) )
workflow_provider_controllers.append(workflow_controller)
except Exception: except Exception:
# app has been deleted # app has been deleted
logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id) logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id)
continue
# Batch get labels for workflow providers
if workflow_provider_controllers:
workflow_controllers: list[ToolProviderController] = [
cast(ToolProviderController, controller) for controller in workflow_provider_controllers
]
labels = ToolLabelManager.get_tools_labels(workflow_controllers)
labels = ToolLabelManager.get_tools_labels( for workflow_provider_controller in workflow_provider_controllers:
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers] provider_labels = labels.get(workflow_provider_controller.provider_id, [])
) user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=workflow_provider_controller,
labels=provider_labels,
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller,
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters: if "mcp" in filters:
with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session)
mcp_service = MCPToolManageService(session=session) mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers: for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider

View File

@ -7,6 +7,7 @@ from httpx import get
from sqlalchemy import select from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController
@ -177,6 +178,9 @@ class ApiToolManageService:
# update labels # update labels
ToolLabelManager.update_tool_labels(provider_controller, labels) ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -318,6 +322,9 @@ class ApiToolManageService:
# update labels # update labels
ToolLabelManager.update_tool_labels(provider_controller, labels) ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -340,6 +347,9 @@ class ApiToolManageService:
db.session.delete(provider) db.session.delete(provider)
db.session.commit() db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod

View File

@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -204,6 +205,9 @@ class BuiltinToolManageService:
db_provider.name = name db_provider.name = name
session.commit() session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e: except Exception as e:
session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
@ -282,6 +286,9 @@ class BuiltinToolManageService:
session.add(db_provider) session.add(db_provider)
session.commit() session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e: except Exception as e:
session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
@ -402,6 +409,9 @@ class BuiltinToolManageService:
) )
cache.delete() cache.delete()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -423,6 +433,9 @@ class BuiltinToolManageService:
# set new default provider # set new default provider
target_provider.is_default = True target_provider.is_default = True
session.commit() session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod

View File

@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError from core.mcp.error import MCPAuthError, MCPError
@ -164,6 +165,10 @@ class MCPToolManageService:
self._session.add(mcp_tool) self._session.add(mcp_tool)
self._session.flush() self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers return mcp_providers
@ -245,6 +250,9 @@ class MCPToolManageService:
# Flush changes to database # Flush changes to database
self._session.flush() self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e: except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier) self._handle_integrity_error(e, name, server_url, server_identifier)
@ -253,6 +261,9 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool) self._session.delete(mcp_tool)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
def list_providers( def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]: ) -> list[ToolProviderApiEntity]:

View File

@ -1,5 +1,6 @@
import logging import logging
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
@ -15,6 +16,14 @@ class ToolCommonService:
:return: the list of tool providers :return: the list of tool providers
""" """
# Try to get from cache first
cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
if cached_result is not None:
logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
return cached_result
# Cache miss - fetch from database
logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon # add icon
@ -23,4 +32,7 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers] result = [provider.to_dict() for provider in providers]
# Cache the result
ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
return result return result

View File

@ -7,6 +7,7 @@ from typing import Any
from sqlalchemy import or_, select from sqlalchemy import or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -91,6 +92,10 @@ class WorkflowToolManageService:
ToolLabelManager.update_tool_labels( ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
) )
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@classmethod @classmethod
@ -178,6 +183,9 @@ class WorkflowToolManageService:
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
) )
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@classmethod @classmethod
@ -240,6 +248,9 @@ class WorkflowToolManageService:
db.session.commit() db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"} return {"result": "success"}
@classmethod @classmethod

View File

@ -0,0 +1,129 @@
import json
from unittest.mock import patch
import pytest
from redis.exceptions import RedisError
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
@pytest.fixture
def mock_redis_client():
"""Fixture: Mock Redis client"""
with patch("core.helper.tool_provider_cache.redis_client") as mock:
yield mock
class TestToolProviderListCache:
"""Test class for ToolProviderListCache"""
def test_generate_cache_key(self):
"""Test cache key generation logic"""
# Scenario 1: Specify typ (valid literal value)
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "builtin"
expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}"
assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key
# Scenario 2: typ is None (defaults to "all")
expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all"
assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all
def test_get_cached_providers_hit(self, mock_redis_client):
"""Test get cached providers - cache hit and successful decoding"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "api"
mock_providers = [{"id": "tool", "name": "test_provider"}]
mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8")
result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ))
assert result == mock_providers
def test_get_cached_providers_decode_error(self, mock_redis_client):
"""Test get cached providers - cache hit but decoding failed"""
tenant_id = "tenant_123"
mock_redis_client.get.return_value = b"invalid_json_data"
result = ToolProviderListCache.get_cached_providers(tenant_id)
assert result is None
mock_redis_client.get.assert_called_once()
def test_get_cached_providers_miss(self, mock_redis_client):
"""Test get cached providers - cache miss"""
tenant_id = "tenant_123"
mock_redis_client.get.return_value = None
result = ToolProviderListCache.get_cached_providers(tenant_id)
assert result is None
mock_redis_client.get.assert_called_once()
def test_set_cached_providers(self, mock_redis_client):
"""Test set cached providers"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "builtin"
mock_providers = [{"id": "tool", "name": "test_provider"}]
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)
mock_redis_client.setex.assert_called_once_with(
cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
)
def test_invalidate_cache_specific_type(self, mock_redis_client):
"""Test invalidate cache - specific type"""
tenant_id = "tenant_123"
typ: ToolProviderTypeApiLiteral = "workflow"
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
ToolProviderListCache.invalidate_cache(tenant_id, typ)
mock_redis_client.delete.assert_called_once_with(cache_key)
def test_invalidate_cache_all_types(self, mock_redis_client):
"""Test invalidate cache - clear all tenant cache"""
tenant_id = "tenant_123"
mock_keys = [
b"tool_providers:tenant_id:tenant_123:type:all",
b"tool_providers:tenant_id:tenant_123:type:builtin",
]
mock_redis_client.scan_iter.return_value = mock_keys
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
mock_redis_client.delete.assert_called_once_with(*mock_keys)
def test_invalidate_cache_no_keys(self, mock_redis_client):
"""Test invalidate cache - no cache keys for tenant"""
tenant_id = "tenant_123"
mock_redis_client.scan_iter.return_value = []
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.delete.assert_not_called()
def test_redis_fallback_default_return(self, mock_redis_client):
"""Test redis_fallback decorator - default return value (Redis error)"""
mock_redis_client.get.side_effect = RedisError("Redis connection error")
result = ToolProviderListCache.get_cached_providers("tenant_123")
assert result is None
mock_redis_client.get.assert_called_once()
def test_redis_fallback_no_default(self, mock_redis_client):
"""Test redis_fallback decorator - no default return value (Redis error)"""
mock_redis_client.setex.side_effect = RedisError("Redis connection error")
try:
ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
except RedisError:
pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")
mock_redis_client.setex.assert_called_once()