From 71497954b8b0d2fcd74e33554d8c6b03804a8a84 Mon Sep 17 00:00:00 2001 From: yangzheli <43645580+yangzheli@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:34:03 +0800 Subject: [PATCH] perf(api): optimize tool provider list API with Redis caching (#29101) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/core/helper/tool_provider_cache.py | 56 ++++++++ api/core/tools/tool_manager.py | 87 +++++++----- .../tools/api_tools_manage_service.py | 10 ++ .../tools/builtin_tools_manage_service.py | 13 ++ .../tools/mcp_tools_manage_service.py | 11 ++ api/services/tools/tools_manage_service.py | 12 ++ .../tools/workflow_tools_manage_service.py | 11 ++ .../core/helper/test_tool_provider_cache.py | 129 ++++++++++++++++++ 8 files changed, 297 insertions(+), 32 deletions(-) create mode 100644 api/core/helper/tool_provider_cache.py create mode 100644 api/tests/unit_tests/core/helper/test_tool_provider_cache.py diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py new file mode 100644 index 0000000000..eef5937407 --- /dev/null +++ b/api/core/helper/tool_provider_cache.py @@ -0,0 +1,56 @@ +import json +import logging +from typing import Any + +from core.tools.entities.api_entities import ToolProviderTypeApiLiteral +from extensions.ext_redis import redis_client, redis_fallback + +logger = logging.getLogger(__name__) + + +class ToolProviderListCache: + """Cache for tool provider lists""" + + CACHE_TTL = 300 # 5 minutes + + @staticmethod + def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str: + """Generate cache key for tool providers list""" + type_filter = typ or "all" + return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}" + + @staticmethod + @redis_fallback(default_return=None) + def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None: + """Get cached tool providers""" + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + cached_data = redis_client.get(cache_key) + if cached_data: + try: + return json.loads(cached_data.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning("Failed to decode cached tool providers data") + return None + return None + + @staticmethod + @redis_fallback() + def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]): + """Cache tool providers""" + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers)) + + @staticmethod + @redis_fallback() + def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None): + """Invalidate cache for tool providers""" + if typ: + # Invalidate specific type cache + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + redis_client.delete(cache_key) + else: + # Invalidate all caches for this tenant + pattern = f"tool_providers:tenant_id:{tenant_id}:*" + keys = list(redis_client.scan_iter(pattern)) + if keys: + redis_client.delete(*keys) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index dd751b8c8d..f8213d9fd7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast import sqlalchemy as sa from sqlalchemy import select @@ -67,6 +67,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ApiProviderControllerItem(TypedDict): + provider: ApiToolProvider + controller: ApiToolProviderController + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -655,9 +660,10 @@ class ToolManager: else: filters.append(typ) - with db.session.no_autoflush: + # Use a single session for all database operations to reduce connection overhead + with Session(db.engine) as session: if "builtin" in filters: - builtin_providers = cls.list_builtin_providers(tenant_id) + builtin_providers = list(cls.list_builtin_providers(tenant_id)) # key: provider name, value: provider db_builtin_providers = { @@ -688,57 +694,74 @@ class ToolManager: # get db api providers if "api" in filters: - db_api_providers = db.session.scalars( + db_api_providers = session.scalars( select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) ).all() - api_provider_controllers: list[dict[str, Any]] = [ - {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} - for provider in db_api_providers - ] + # Batch create controllers + api_provider_controllers: list[ApiProviderControllerItem] = [] + for api_provider in db_api_providers: + try: + controller = ToolTransformService.api_provider_to_controller(api_provider) + api_provider_controllers.append({"provider": api_provider, "controller": controller}) + except Exception: + # Skip invalid providers but continue processing others + logger.warning("Failed to create controller for API provider %s", api_provider.id) - # get labels - labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) - - for api_provider_controller in api_provider_controllers: - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller["controller"], - db_provider=api_provider_controller["provider"], - decrypt_credentials=False, - labels=labels.get(api_provider_controller["controller"].provider_id, []), + # Batch get labels for all API providers + if api_provider_controllers: + controllers = cast( + list[ToolProviderController], [item["controller"] for item in api_provider_controllers] ) - result_providers[f"api_provider.{user_provider.name}"] = user_provider + labels = ToolLabelManager.get_tools_labels(controllers) + + for item in api_provider_controllers: + provider_controller = item["controller"] + db_provider = item["provider"] + provider_labels = labels.get(provider_controller.provider_id, []) + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=db_provider, + decrypt_credentials=False, + labels=provider_labels, + ) + result_providers[f"api_provider.{user_provider.name}"] = user_provider if "workflow" in filters: # get workflow providers - workflow_providers = db.session.scalars( + workflow_providers = session.scalars( select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) ).all() workflow_provider_controllers: list[WorkflowToolProviderController] = [] for workflow_provider in workflow_providers: try: - workflow_provider_controllers.append( + workflow_controller: WorkflowToolProviderController = ( ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) ) + workflow_provider_controllers.append(workflow_controller) except Exception: # app has been deleted logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id) + continue + # Batch get labels for workflow providers + if workflow_provider_controllers: + workflow_controllers: list[ToolProviderController] = [ + cast(ToolProviderController, controller) for controller in workflow_provider_controllers + ] + labels = ToolLabelManager.get_tools_labels(workflow_controllers) - labels = ToolLabelManager.get_tools_labels( - [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] - ) + for workflow_provider_controller in workflow_provider_controllers: + provider_labels = labels.get(workflow_provider_controller.provider_id, []) + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=workflow_provider_controller, + labels=provider_labels, + ) + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider - for provider_controller in workflow_provider_controllers: - user_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=provider_controller, - labels=labels.get(provider_controller.provider_id, []), - ) - result_providers[f"workflow_provider.{user_provider.name}"] = user_provider if "mcp" in filters: - with Session(db.engine) as session: - mcp_service = MCPToolManageService(session=session) - mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True) + mcp_service = MCPToolManageService(session=session) + mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True) for mcp_provider in mcp_providers: result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 250d29f335..b3b6e36346 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,6 +7,7 @@ from httpx import get from sqlalchemy import select from core.entities.provider_entities import ProviderConfig +from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController @@ -177,6 +178,9 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod @@ -318,6 +322,9 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod @@ -340,6 +347,9 @@ class ApiToolManageService: db.session.delete(provider) db.session.commit() + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 783f2f0d21..cf1d39fa25 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache +from core.helper.tool_provider_cache import ToolProviderListCache from core.plugin.entities.plugin_daemon import CredentialType from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -204,6 +205,9 @@ class BuiltinToolManageService: db_provider.name = name session.commit() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -282,6 +286,9 @@ class BuiltinToolManageService: session.add(db_provider) session.commit() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -402,6 +409,9 @@ class BuiltinToolManageService: ) cache.delete() + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod @@ -423,6 +433,9 @@ class BuiltinToolManageService: # set new default provider target_provider.is_default = True session.commit() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} @staticmethod diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 7eedf76aed..d641fe0315 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache +from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError @@ -164,6 +165,10 @@ class MCPToolManageService: self._session.add(mcp_tool) self._session.flush() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) return mcp_providers @@ -245,6 +250,9 @@ class MCPToolManageService: # Flush changes to database self._session.flush() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) except IntegrityError as e: self._handle_integrity_error(e, name, server_url, server_identifier) @@ -253,6 +261,9 @@ class MCPToolManageService: mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + def list_providers( self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True ) -> list[ToolProviderApiEntity]: diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 51e9120b8d..038c462f15 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,5 +1,6 @@ import logging +from core.helper.tool_provider_cache import ToolProviderListCache from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager from services.tools.tools_transform_service import ToolTransformService @@ -15,6 +16,14 @@ class ToolCommonService: :return: the list of tool providers """ + # Try to get from cache first + cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ) + if cached_result is not None: + logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ) + return cached_result + + # Cache miss - fetch from database + logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ) providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) # add icon @@ -23,4 +32,7 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] + # Cache the result + ToolProviderListCache.set_cached_providers(tenant_id, typ, result) + return result diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index d89b38d563..fe77ff2dc5 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session +from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -91,6 +92,10 @@ class WorkflowToolManageService: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @classmethod @@ -178,6 +183,9 @@ class WorkflowToolManageService: ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @classmethod @@ -240,6 +248,9 @@ class WorkflowToolManageService: db.session.commit() + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @classmethod diff --git a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py new file mode 100644 index 0000000000..00f7c9d7e9 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py @@ -0,0 +1,129 @@ +import json +from unittest.mock import patch + +import pytest +from redis.exceptions import RedisError + +from core.helper.tool_provider_cache import ToolProviderListCache +from core.tools.entities.api_entities import ToolProviderTypeApiLiteral + + +@pytest.fixture +def mock_redis_client(): + """Fixture: Mock Redis client""" + with patch("core.helper.tool_provider_cache.redis_client") as mock: + yield mock + + +class TestToolProviderListCache: + """Test class for ToolProviderListCache""" + + def test_generate_cache_key(self): + """Test cache key generation logic""" + # Scenario 1: Specify typ (valid literal value) + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "builtin" + expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}" + assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key + + # Scenario 2: typ is None (defaults to "all") + expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all" + assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all + + def test_get_cached_providers_hit(self, mock_redis_client): + """Test get cached providers - cache hit and successful decoding""" + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "api" + mock_providers = [{"id": "tool", "name": "test_provider"}] + mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8") + + result = ToolProviderListCache.get_cached_providers(tenant_id, typ) + + mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ)) + assert result == mock_providers + + def test_get_cached_providers_decode_error(self, mock_redis_client): + """Test get cached providers - cache hit but decoding failed""" + tenant_id = "tenant_123" + mock_redis_client.get.return_value = b"invalid_json_data" + + result = ToolProviderListCache.get_cached_providers(tenant_id) + + assert result is None + mock_redis_client.get.assert_called_once() + + def test_get_cached_providers_miss(self, mock_redis_client): + """Test get cached providers - cache miss""" + tenant_id = "tenant_123" + mock_redis_client.get.return_value = None + + result = ToolProviderListCache.get_cached_providers(tenant_id) + + assert result is None + mock_redis_client.get.assert_called_once() + + def test_set_cached_providers(self, mock_redis_client): + """Test set cached providers""" + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "builtin" + mock_providers = [{"id": "tool", "name": "test_provider"}] + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + + ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers) + + mock_redis_client.setex.assert_called_once_with( + cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers) + ) + + def test_invalidate_cache_specific_type(self, mock_redis_client): + """Test invalidate cache - specific type""" + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "workflow" + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + + ToolProviderListCache.invalidate_cache(tenant_id, typ) + + mock_redis_client.delete.assert_called_once_with(cache_key) + + def test_invalidate_cache_all_types(self, mock_redis_client): + """Test invalidate cache - clear all tenant cache""" + tenant_id = "tenant_123" + mock_keys = [ + b"tool_providers:tenant_id:tenant_123:type:all", + b"tool_providers:tenant_id:tenant_123:type:builtin", + ] + mock_redis_client.scan_iter.return_value = mock_keys + + ToolProviderListCache.invalidate_cache(tenant_id) + + mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*") + mock_redis_client.delete.assert_called_once_with(*mock_keys) + + def test_invalidate_cache_no_keys(self, mock_redis_client): + """Test invalidate cache - no cache keys for tenant""" + tenant_id = "tenant_123" + mock_redis_client.scan_iter.return_value = [] + + ToolProviderListCache.invalidate_cache(tenant_id) + + mock_redis_client.delete.assert_not_called() + + def test_redis_fallback_default_return(self, mock_redis_client): + """Test redis_fallback decorator - default return value (Redis error)""" + mock_redis_client.get.side_effect = RedisError("Redis connection error") + + result = ToolProviderListCache.get_cached_providers("tenant_123") + + assert result is None + mock_redis_client.get.assert_called_once() + + def test_redis_fallback_no_default(self, mock_redis_client): + """Test redis_fallback decorator - no default return value (Redis error)""" + mock_redis_client.setex.side_effect = RedisError("Redis connection error") + + try: + ToolProviderListCache.set_cached_providers("tenant_123", "mcp", []) + except RedisError: + pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)") + + mock_redis_client.setex.assert_called_once()