diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d51b37a9cd..e9e7b72718 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -20,7 +20,6 @@ from controllers.console.wraps import ( ) from core.db.session_factory import session_factory from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -987,9 +986,6 @@ class ToolProviderMCPApi(Resource): # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is logger.warning("Failed to fetch MCP tools after creation", exc_info=True) - # Final cache invalidation to ensure list views are up to date - ToolProviderListCache.invalidate_cache(tenant_id) - return jsonable_encoder(result) @console_ns.expect(parser_mcp_put) @@ -1036,9 +1032,6 @@ class ToolProviderMCPApi(Resource): validation_result=validation_result, ) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(current_tenant_id) - return {"result": "success"} @console_ns.expect(parser_mcp_delete) @@ -1053,9 +1046,6 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(current_tenant_id) - return {"result": "success"} @@ -1106,8 +1096,6 @@ class ToolMCPAuthApi(Resource): credentials=provider_entity.credentials, authed=True, ) - # Invalidate cache after updating credentials - ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} except MCPAuthError as e: try: @@ -1121,22 +1109,16 @@ class ToolMCPAuthApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) - # Invalidate cache after auth actions may have updated provider state - ToolProviderListCache.invalidate_cache(tenant_id) return response except MCPRefreshTokenError as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) - # Invalidate cache after clearing credentials - ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except (MCPError, ValueError) as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) - # Invalidate cache after clearing credentials - ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index c5447c2b3f..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -import logging -from typing import Any, cast - -from core.tools.entities.api_entities import ToolProviderTypeApiLiteral -from extensions.ext_redis import redis_client, redis_fallback - -logger = logging.getLogger(__name__) - - -class ToolProviderListCache: - """Cache for tool provider lists""" - - CACHE_TTL = 300 # 5 minutes - - @staticmethod - def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str: - """Generate cache key for tool providers list""" - type_filter = typ or "all" - return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}" - - @staticmethod - @redis_fallback(default_return=None) - def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None: - """Get cached tool providers""" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - cached_data = redis_client.get(cache_key) - if cached_data: - try: - return json.loads(cached_data.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError): - logger.warning("Failed to decode cached tool providers data") - return None - return None - - @staticmethod - @redis_fallback() - def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]): - """Cache tool providers""" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers)) - - @staticmethod - @redis_fallback() - def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None): - """Invalidate cache for tool providers""" - if typ: - # Invalidate specific type cache - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - redis_client.delete(cache_key) - else: - # Invalidate all caches for this tenant - keys = ["builtin", "model", "api", "workflow", "mcp"] - pipeline = redis_client.pipeline() - for key in keys: - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key)) - pipeline.delete(cache_key) - pipeline.execute() diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index b3b6e36346..250d29f335 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,7 +7,6 @@ from httpx import get from sqlalchemy import select from core.entities.provider_entities import ProviderConfig -from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController @@ -178,9 +177,6 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -322,9 +318,6 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -347,9 +340,6 @@ class ApiToolManageService: db.session.delete(provider) db.session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 87951d53e6..6797a67dde 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.plugin.entities.plugin_daemon import CredentialType from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -205,9 +204,6 @@ class BuiltinToolManageService: db_provider.name = name session.commit() - - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -290,8 +286,6 @@ class BuiltinToolManageService: session.rollback() raise ValueError(str(e)) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id, "builtin") return {"result": "success"} @staticmethod @@ -409,9 +403,6 @@ class BuiltinToolManageService: ) cache.delete() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -434,8 +425,6 @@ class BuiltinToolManageService: target_provider.is_default = True session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} @staticmethod diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 038c462f15..51e9120b8d 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,6 +1,5 @@ import logging -from core.helper.tool_provider_cache import ToolProviderListCache from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager from services.tools.tools_transform_service import ToolTransformService @@ -16,14 +15,6 @@ class ToolCommonService: :return: the list of tool providers """ - # Try to get from cache first - cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ) - if cached_result is not None: - logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ) - return cached_result - - # Cache miss - fetch from database - logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ) providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) # add icon @@ -32,7 +23,4 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] - # Cache the result - ToolProviderListCache.set_cached_providers(tenant_id, typ, result) - return result diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 714a651839..ab5d5480df 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -5,9 +5,8 @@ from datetime import datetime from typing import Any from sqlalchemy import or_, select +from sqlalchemy.orm import Session -from core.db.session_factory import session_factory -from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -86,17 +85,13 @@ class WorkflowToolManageService: except Exception as e: raise ValueError(str(e)) - with session_factory.create_session() as session, session.begin(): + with Session(db.engine, expire_on_commit=False) as session, session.begin(): session.add(workflow_tool_provider) if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod @@ -184,9 +179,6 @@ class WorkflowToolManageService: ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod @@ -249,9 +241,6 @@ class WorkflowToolManageService: db.session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py index 2b03813ef4..c608f731c5 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -41,13 +41,10 @@ def client(): @patch( "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1") ) -@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None) @patch("controllers.console.workspace.tool_providers.Session") @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url") @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") -def test_create_mcp_provider_populates_tools( - mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client -): +def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): # Arrange: reconnect returns tools immediately mock_reconnect.return_value = ReconnectResult( authed=True, diff --git a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py deleted file mode 100644 index d237c68f35..0000000000 --- a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py +++ /dev/null @@ -1,126 +0,0 @@ -import json -from unittest.mock import patch - -import pytest -from redis.exceptions import RedisError - -from core.helper.tool_provider_cache import ToolProviderListCache -from core.tools.entities.api_entities import ToolProviderTypeApiLiteral - - -@pytest.fixture -def mock_redis_client(): - """Fixture: Mock Redis client""" - with patch("core.helper.tool_provider_cache.redis_client") as mock: - yield mock - - -class TestToolProviderListCache: - """Test class for ToolProviderListCache""" - - def test_generate_cache_key(self): - """Test cache key generation logic""" - # Scenario 1: Specify typ (valid literal value) - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "builtin" - expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}" - assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key - - # Scenario 2: typ is None (defaults to "all") - expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all" - assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all - - def test_get_cached_providers_hit(self, mock_redis_client): - """Test get cached providers - cache hit and successful decoding""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "api" - mock_providers = [{"id": "tool", "name": "test_provider"}] - mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8") - - result = ToolProviderListCache.get_cached_providers(tenant_id, typ) - - mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ)) - assert result == mock_providers - - def test_get_cached_providers_decode_error(self, mock_redis_client): - """Test get cached providers - cache hit but decoding failed""" - tenant_id = "tenant_123" - mock_redis_client.get.return_value = b"invalid_json_data" - - result = ToolProviderListCache.get_cached_providers(tenant_id) - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_get_cached_providers_miss(self, mock_redis_client): - """Test get cached providers - cache miss""" - tenant_id = "tenant_123" - mock_redis_client.get.return_value = None - - result = ToolProviderListCache.get_cached_providers(tenant_id) - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_set_cached_providers(self, mock_redis_client): - """Test set cached providers""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "builtin" - mock_providers = [{"id": "tool", "name": "test_provider"}] - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - - ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers) - - mock_redis_client.setex.assert_called_once_with( - cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers) - ) - - def test_invalidate_cache_specific_type(self, mock_redis_client): - """Test invalidate cache - specific type""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "workflow" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - - ToolProviderListCache.invalidate_cache(tenant_id, typ) - - mock_redis_client.delete.assert_called_once_with(cache_key) - - def test_invalidate_cache_all_types(self, mock_redis_client): - """Test invalidate cache - clear all tenant cache""" - tenant_id = "tenant_123" - mock_keys = [ - b"tool_providers:tenant_id:tenant_123:type:all", - b"tool_providers:tenant_id:tenant_123:type:builtin", - ] - mock_redis_client.scan_iter.return_value = mock_keys - - ToolProviderListCache.invalidate_cache(tenant_id) - - def test_invalidate_cache_no_keys(self, mock_redis_client): - """Test invalidate cache - no cache keys for tenant""" - tenant_id = "tenant_123" - mock_redis_client.scan_iter.return_value = [] - - ToolProviderListCache.invalidate_cache(tenant_id) - - mock_redis_client.delete.assert_not_called() - - def test_redis_fallback_default_return(self, mock_redis_client): - """Test redis_fallback decorator - default return value (Redis error)""" - mock_redis_client.get.side_effect = RedisError("Redis connection error") - - result = ToolProviderListCache.get_cached_providers("tenant_123") - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_redis_fallback_no_default(self, mock_redis_client): - """Test redis_fallback decorator - no default return value (Redis error)""" - mock_redis_client.setex.side_effect = RedisError("Redis connection error") - - try: - ToolProviderListCache.set_cached_providers("tenant_123", "mcp", []) - except RedisError: - pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)") - - mock_redis_client.setex.assert_called_once()