From 7501360663efe879d716d8a962ce80181a21596b Mon Sep 17 00:00:00 2001 From: Novice Date: Sun, 21 Dec 2025 09:19:11 +0800 Subject: [PATCH] fix: add RFC 9728 compliant well-known URL discovery with path insertion fallback (#29960) --- .../console/workspace/tool_providers.py | 39 ++++-- api/core/mcp/auth/auth_flow.py | 55 +++++--- api/core/mcp/mcp_client.py | 2 +- .../tools/mcp_tools_manage_service.py | 120 ++++++++++++------ .../tools/test_mcp_tools_manage_service.py | 41 +++--- 5 files changed, 170 insertions(+), 87 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2c54aa5a20..a2fc45c29c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,7 @@ from controllers.console.wraps import ( setup_required, ) from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration +from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource): configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None - # Create provider + # Create provider in transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) result = service.create_provider( @@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource): configuration=configuration, authentication=authentication, ) - return jsonable_encoder(result) + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(tenant_id) + + return jsonable_encoder(result) @console_ns.expect(parser_mcp_put) @setup_required @@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource): authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() - # Step 1: Validate server URL change if needed (includes URL format validation and network operation) - validation_result = None + # Step 1: Get provider data for URL validation (short-lived session, no network I/O) + validation_data = None with Session(db.engine) as session: service = MCPToolManageService(session=session) - validation_result = service.validate_server_url_change( - tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + validation_data = service.get_provider_for_url_validation( + tenant_id=current_tenant_id, provider_id=args["provider_id"] ) - # No need to check for errors here, exceptions will be raised directly + # Step 2: Perform URL validation with network I/O OUTSIDE of any database session + # This prevents holding database locks during potentially slow network operations + validation_result = MCPToolManageService.validate_server_url_standalone( + tenant_id=current_tenant_id, + new_server_url=args["server_url"], + validation_data=validation_data, + ) - # Step 2: Perform database update in a transaction + # Step 3: Perform database update in a transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( @@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource): authentication=authentication, validation_result=validation_result, ) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} @console_ns.expect(parser_mcp_delete) @setup_required @@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} parser_auth = ( diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 92787b39dd..aef1afb235 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls( """ Build a list of URLs to try for Protected Resource Metadata discovery. - Per SEP-985, supports fallback when discovery fails at one URL. + Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL. + Priority order: + 1. URL from WWW-Authenticate header (if provided) + 2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp + 3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource """ urls = [] @@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls( # Fallback: construct from server URL parsed = urlparse(server_url) base_url = f"{parsed.scheme}://{parsed.netloc}" - fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") - if fallback_url not in urls: - urls.append(fallback_url) + path = parsed.path.rstrip("/") + + # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp) + if path: + path_url = f"{base_url}/.well-known/oauth-protected-resource{path}" + if path_url not in urls: + urls.append(path_url) + + # Priority 3: At root (e.g., /.well-known/oauth-protected-resource) + root_url = f"{base_url}/.well-known/oauth-protected-resource" + if root_url not in urls: + urls.append(root_url) return urls @@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. - Per RFC 8414 section 3: - - If issuer has no path: https://example.com/.well-known/oauth-authorization-server - - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} - - Example: - - issuer: https://example.com/oauth - - metadata: https://example.com/.well-known/oauth-authorization-server/oauth + Per RFC 8414 section 3.1 and section 5, try all possible endpoints: + - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1 + - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1 + - OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration + - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server + - OpenID Connect at root: https://example.com/.well-known/openid-configuration """ urls = [] base_url = auth_server_url or server_url parsed = urlparse(base_url) base = f"{parsed.scheme}://{parsed.netloc}" - path = parsed.path.rstrip("/") # Remove trailing slash + path = parsed.path.rstrip("/") + # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26) + urls.append(f"{base}/.well-known/oauth-authorization-server") - # Try OpenID Connect discovery first (more common) - urls.append(urljoin(base + "/", ".well-known/openid-configuration")) + # OpenID Connect Discovery at root + urls.append(f"{base}/.well-known/openid-configuration") - # OAuth 2.0 Authorization Server Metadata (RFC 8414) - # Include the path component if present in the issuer URL if path: - urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) - else: - urls.append(urljoin(base, ".well-known/oauth-authorization-server")) + # OpenID Connect Discovery with path insertion + urls.append(f"{base}/.well-known/openid-configuration{path}") + + # OpenID Connect Discovery path appending + urls.append(f"{base}{path}/.well-known/openid-configuration") + + # OAuth 2.0 Authorization Server Metadata with path insertion + urls.append(f"{base}/.well-known/oauth-authorization-server{path}") return urls diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index b0e0dab9be..2b0645b558 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -59,7 +59,7 @@ class MCPClient: try: logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) self.connect_server(sse_client, "sse") - except MCPConnectionError: + except (MCPConnectionError, ValueError): logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index d641fe0315..252be77b27 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -15,7 +15,6 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError @@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel): return self.needs_validation and self.validation_passed and self.reconnect_result is not None +class ProviderUrlValidationData(BaseModel): + """Data required for URL validation, extracted from database to perform network operations outside of session""" + + current_server_url_hash: str + headers: dict[str, str] + timeout: float | None + sse_read_timeout: float | None + + class MCPToolManageService: """Service class for managing MCP tools and providers.""" @@ -166,9 +174,6 @@ class MCPToolManageService: self._session.add(mcp_tool) self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) return mcp_providers @@ -192,7 +197,7 @@ class MCPToolManageService: Update an MCP provider. Args: - validation_result: Pre-validation result from validate_server_url_change. + validation_result: Pre-validation result from validate_server_url_standalone. If provided and contains reconnect_result, it will be used instead of performing network operations. """ @@ -251,8 +256,6 @@ class MCPToolManageService: # Flush changes to database self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except IntegrityError as e: self._handle_integrity_error(e, name, server_url, server_identifier) @@ -261,9 +264,6 @@ class MCPToolManageService: mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - def list_providers( self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True ) -> list[ToolProviderApiEntity]: @@ -546,30 +546,39 @@ class MCPToolManageService: ) return self.execute_auth_actions(auth_result) - def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: - """Attempt to reconnect to MCP provider with new server URL.""" + def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData: + """ + Get provider data required for URL validation. + This method performs database read and should be called within a session. + + Returns: + ProviderUrlValidationData: Data needed for standalone URL validation + """ + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider_entity = provider.to_entity() - headers = provider_entity.headers + return ProviderUrlValidationData( + current_server_url_hash=provider.server_url_hash, + headers=provider_entity.headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ) - try: - tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) - return ReconnectResult( - authed=True, - tools=json.dumps([tool.model_dump() for tool in tools]), - encrypted_credentials=EMPTY_CREDENTIALS_JSON, - ) - except MCPAuthError: - return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) - except MCPError as e: - raise ValueError(f"Failed to re-connect MCP server: {e}") from e - - def validate_server_url_change( - self, *, tenant_id: str, provider_id: str, new_server_url: str + @staticmethod + def validate_server_url_standalone( + *, + tenant_id: str, + new_server_url: str, + validation_data: ProviderUrlValidationData, ) -> ServerUrlValidationResult: """ Validate server URL change by attempting to connect to the new server. - This method should be called BEFORE update_provider to perform network operations - outside of the database transaction. + This method performs network operations and MUST be called OUTSIDE of any database session + to avoid holding locks during network I/O. + + Args: + tenant_id: Tenant ID for encryption + new_server_url: The new server URL to validate + validation_data: Provider data obtained from get_provider_for_url_validation Returns: ServerUrlValidationResult: Validation result with connection status and tools if successful @@ -579,25 +588,30 @@ class MCPToolManageService: return ServerUrlValidationResult(needs_validation=False) # Validate URL format - if not self._is_valid_url(new_server_url): + parsed = urlparse(new_server_url) + if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]: raise ValueError("Server URL is not valid.") # Always encrypt and hash the URL encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() - # Get current provider - provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - # Check if URL is actually different - if new_server_url_hash == provider.server_url_hash: + if new_server_url_hash == validation_data.current_server_url_hash: # URL hasn't changed, but still return the encrypted data return ServerUrlValidationResult( - needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + needs_validation=False, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, ) - # Perform validation by attempting to connect - reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + # Perform network validation - this is the expensive operation that should be outside session + reconnect_result = MCPToolManageService._reconnect_with_url( + server_url=new_server_url, + headers=validation_data.headers, + timeout=validation_data.timeout, + sse_read_timeout=validation_data.sse_read_timeout, + ) return ServerUrlValidationResult( needs_validation=True, validation_passed=True, @@ -606,6 +620,38 @@ class MCPToolManageService: server_url_hash=new_server_url_hash, ) + @staticmethod + def _reconnect_with_url( + *, + server_url: str, + headers: dict[str, str], + timeout: float | None, + sse_read_timeout: float | None, + ) -> ReconnectResult: + """ + Attempt to connect to MCP server with given URL. + This is a static method that performs network I/O without database access. + """ + from core.mcp.mcp_client import MCPClient + + try: + with MCPClient( + server_url=server_url, + headers=headers, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + ) as mcp_client: + tools = mcp_client.list_tools() + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, + ) + except MCPAuthError: + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) + except MCPError as e: + raise ValueError(f"Failed to re-connect MCP server: {e}") from e + def _build_tool_provider_response( self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list ) -> ToolProviderApiEntity: diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 8c190762cf..6cae83ac37 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1308,18 +1308,17 @@ class TestMCPToolManageService: type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1337,8 +1336,12 @@ class TestMCPToolManageService: assert tools_data[1]["name"] == "test_tool_2" # Verify mock interactions - provider_entity = mcp_provider.to_entity() - mock_mcp_client.assert_called_once() + mock_mcp_client.assert_called_once_with( + server_url="https://example.com/mcp", + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, + ) def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -1361,19 +1364,18 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1404,18 +1406,17 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): - service._reconnect_provider( + MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, )