From 5ab315aeafeb702264d15190fb3465b9ddb83c96 Mon Sep 17 00:00:00 2001 From: Vivec <72788785+Vivecccccc@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:11:45 +0800 Subject: [PATCH] fix: set conditional capabilities upon MCP client session initialization (#26234) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Novice --- api/core/entities/mcp_provider.py | 3 ++- api/core/mcp/session/client_session.py | 16 ++++++++++------ api/services/tools/tools_transform_service.py | 3 ++- .../unit_tests/core/mcp/client/test_session.py | 3 --- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 4ac39cef02..7484cea04a 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from core.tools.utils.encryption import create_provider_encrypter if TYPE_CHECKING: from models.tools import MCPToolProvider @@ -272,6 +271,8 @@ class MCPProviderEntity(BaseModel): def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]: """Generic method to decrypt dictionary fields""" + from core.tools.utils.encryption import create_provider_encrypter + if not data: return {} diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 77d35cca19..d684fe0dd7 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -109,12 +109,16 @@ class ClientSession( self._message_handler = message_handler or _default_message_handler def initialize(self) -> types.InitializeResult: - sampling = types.SamplingCapability() - roots = types.RootsCapability( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - listChanged=True, + # Only set capabilities if non-default callbacks are provided + # This prevents servers from attempting callbacks when we don't actually support them + sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None + roots = ( + types.RootsCapability( + # Only enable listChanged if we have a custom callback + listChanged=True, + ) + if self._list_roots_callback is not _default_list_roots_callback + else None ) result = self.send_request( diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6e95513318..ab80af7a8d 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -7,7 +7,6 @@ from pydantic import ValidationError from yarl import URL from configs import dify_config -from core.entities.mcp_provider import MCPConfiguration from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity @@ -240,6 +239,8 @@ class ToolTransformService: user_name: str | None = None, include_sensitive: bool = True, ) -> ToolProviderApiEntity: + from core.entities.mcp_provider import MCPConfiguration + # Use provided user_name to avoid N+1 query, fallback to load_user() if not provided if user_name is None: user = db_provider.load_user() diff --git a/api/tests/unit_tests/core/mcp/client/test_session.py b/api/tests/unit_tests/core/mcp/client/test_session.py index 08d5b7d21c..8b24c8ce75 100644 --- a/api/tests/unit_tests/core/mcp/client/test_session.py +++ b/api/tests/unit_tests/core/mcp/client/test_session.py @@ -395,9 +395,6 @@ def test_client_capabilities_default(): # Assert default capabilities assert received_capabilities is not None - assert received_capabilities.sampling is not None - assert received_capabilities.roots is not None - assert received_capabilities.roots.listChanged is True def test_client_capabilities_with_custom_callbacks():