fix: add RFC 9728 compliant well-known URL discovery with path insertion fallback (#29960)

This commit is contained in:
Novice 2025-12-21 09:19:11 +08:00 committed by GitHub
parent 7b60ff3d2d
commit 7501360663
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 170 additions and 87 deletions

View File

@ -18,6 +18,7 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"]) configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider # Create provider in transaction
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
result = service.create_provider( result = service.create_provider(
@ -960,6 +961,10 @@ class ToolProviderMCPApi(Resource):
configuration=configuration, configuration=configuration,
authentication=authentication, authentication=authentication,
) )
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result) return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put) @console_ns.expect(parser_mcp_put)
@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation) # Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_result = None validation_data = None
with Session(db.engine) as session: with Session(db.engine) as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change( validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] tenant_id=current_tenant_id, provider_id=args["provider_id"]
) )
# No need to check for errors here, exceptions will be raised directly # Step 2: Perform URL validation with network I/O OUTSIDE of any database session
# This prevents holding database locks during potentially slow network operations
validation_result = MCPToolManageService.validate_server_url_standalone(
tenant_id=current_tenant_id,
new_server_url=args["server_url"],
validation_data=validation_data,
)
# Step 2: Perform database update in a transaction # Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.update_provider( service.update_provider(
@ -999,6 +1010,10 @@ class ToolProviderMCPApi(Resource):
authentication=authentication, authentication=authentication,
validation_result=validation_result, validation_result=validation_result,
) )
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"} return {"result": "success"}
@console_ns.expect(parser_mcp_delete) @console_ns.expect(parser_mcp_delete)
@ -1012,6 +1027,10 @@ class ToolProviderMCPApi(Resource):
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"} return {"result": "success"}

View File

@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
""" """
Build a list of URLs to try for Protected Resource Metadata discovery. Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL. Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
Priority order:
1. URL from WWW-Authenticate header (if provided)
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
""" """
urls = [] urls = []
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL # Fallback: construct from server URL
parsed = urlparse(server_url) parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}" base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") path = parsed.path.rstrip("/")
if fallback_url not in urls:
urls.append(fallback_url) # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
if path_url not in urls:
urls.append(path_url)
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
root_url = f"{base_url}/.well-known/oauth-protected-resource"
if root_url not in urls:
urls.append(root_url)
return urls return urls
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3: Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
Example: - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
- issuer: https://example.com/oauth - OpenID Connect at root: https://example.com/.well-known/openid-configuration
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
""" """
urls = [] urls = []
base_url = auth_server_url or server_url base_url = auth_server_url or server_url
parsed = urlparse(base_url) parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}" base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash path = parsed.path.rstrip("/")
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
urls.append(f"{base}/.well-known/oauth-authorization-server")
# Try OpenID Connect discovery first (more common) # OpenID Connect Discovery at root
urls.append(urljoin(base + "/", ".well-known/openid-configuration")) urls.append(f"{base}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path: if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) # OpenID Connect Discovery with path insertion
else: urls.append(f"{base}/.well-known/openid-configuration{path}")
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
# OpenID Connect Discovery path appending
urls.append(f"{base}{path}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata with path insertion
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
return urls return urls

View File

@ -59,7 +59,7 @@ class MCPClient:
try: try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse") self.connect_server(sse_client, "sse")
except MCPConnectionError: except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp") self.connect_server(streamablehttp_client, "mcp")

View File

@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError from core.mcp.error import MCPAuthError, MCPError
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
return self.needs_validation and self.validation_passed and self.reconnect_result is not None return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class ProviderUrlValidationData(BaseModel):
    """
    Snapshot of the provider fields that server URL validation needs.

    The values are copied out of the database record so that the network
    part of the validation can run after the database session is closed.
    """

    # Hash stored on the provider record; compared against the SHA-256 hex
    # digest of the new server URL to decide whether the URL really changed.
    current_server_url_hash: str
    # Headers taken from the provider entity and handed to the MCP client.
    headers: dict[str, str]
    # Timeout taken from the provider entity and handed to the MCP client.
    timeout: float | None
    # SSE read timeout taken from the provider entity and handed to the MCP client.
    sse_read_timeout: float | None
class MCPToolManageService: class MCPToolManageService:
"""Service class for managing MCP tools and providers.""" """Service class for managing MCP tools and providers."""
@ -166,9 +174,6 @@ class MCPToolManageService:
self._session.add(mcp_tool) self._session.add(mcp_tool)
self._session.flush() self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers return mcp_providers
@ -192,7 +197,7 @@ class MCPToolManageService:
Update an MCP provider. Update an MCP provider.
Args: Args:
validation_result: Pre-validation result from validate_server_url_change. validation_result: Pre-validation result from validate_server_url_standalone.
If provided and contains reconnect_result, it will be used If provided and contains reconnect_result, it will be used
instead of performing network operations. instead of performing network operations.
""" """
@ -251,8 +256,6 @@ class MCPToolManageService:
# Flush changes to database # Flush changes to database
self._session.flush() self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e: except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier) self._handle_integrity_error(e, name, server_url, server_identifier)
@ -261,9 +264,6 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool) self._session.delete(mcp_tool)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
def list_providers( def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]: ) -> list[ToolProviderApiEntity]:
@ -546,30 +546,39 @@ class MCPToolManageService:
) )
return self.execute_auth_actions(auth_result) return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
"""Attempt to reconnect to MCP provider with new server URL.""" """
Get provider data required for URL validation.
This method performs database read and should be called within a session.
Returns:
ProviderUrlValidationData: Data needed for standalone URL validation
"""
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = provider.to_entity() provider_entity = provider.to_entity()
headers = provider_entity.headers return ProviderUrlValidationData(
current_server_url_hash=provider.server_url_hash,
try: headers=provider_entity.headers,
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) timeout=provider_entity.timeout,
return ReconnectResult( sse_read_timeout=provider_entity.sse_read_timeout,
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
) )
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change( @staticmethod
self, *, tenant_id: str, provider_id: str, new_server_url: str def validate_server_url_standalone(
*,
tenant_id: str,
new_server_url: str,
validation_data: ProviderUrlValidationData,
) -> ServerUrlValidationResult: ) -> ServerUrlValidationResult:
""" """
Validate server URL change by attempting to connect to the new server. Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations This method performs network operations and MUST be called OUTSIDE of any database session
outside of the database transaction. to avoid holding locks during network I/O.
Args:
tenant_id: Tenant ID for encryption
new_server_url: The new server URL to validate
validation_data: Provider data obtained from get_provider_for_url_validation
Returns: Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful ServerUrlValidationResult: Validation result with connection status and tools if successful
@ -579,25 +588,30 @@ class MCPToolManageService:
return ServerUrlValidationResult(needs_validation=False) return ServerUrlValidationResult(needs_validation=False)
# Validate URL format # Validate URL format
if not self._is_valid_url(new_server_url): parsed = urlparse(new_server_url)
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
raise ValueError("Server URL is not valid.") raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL # Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different # Check if URL is actually different
if new_server_url_hash == provider.server_url_hash: if new_server_url_hash == validation_data.current_server_url_hash:
# URL hasn't changed, but still return the encrypted data # URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult( return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash needs_validation=False,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
) )
# Perform validation by attempting to connect # Perform network validation - this is the expensive operation that should be outside session
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) reconnect_result = MCPToolManageService._reconnect_with_url(
server_url=new_server_url,
headers=validation_data.headers,
timeout=validation_data.timeout,
sse_read_timeout=validation_data.sse_read_timeout,
)
return ServerUrlValidationResult( return ServerUrlValidationResult(
needs_validation=True, needs_validation=True,
validation_passed=True, validation_passed=True,
@ -606,6 +620,38 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash, server_url_hash=new_server_url_hash,
) )
@staticmethod
def _reconnect_with_url(
    *,
    server_url: str,
    headers: dict[str, str],
    timeout: float | None,
    sse_read_timeout: float | None,
) -> ReconnectResult:
    """
    Connect to an MCP server at ``server_url`` and list its tools.

    Static and free of database access: everything it needs is passed in,
    so it can be called after the database session has been closed.

    Args:
        server_url: URL of the MCP server to connect to.
        headers: Headers forwarded to the MCP client.
        timeout: Timeout forwarded to the MCP client.
        sse_read_timeout: SSE read timeout forwarded to the MCP client.

    Returns:
        ReconnectResult: ``authed=True`` with the JSON-serialized tool list on
        success; ``authed=False`` with the empty-tools placeholder when the
        server raises ``MCPAuthError``. Credentials are the empty placeholder
        in both cases.

    Raises:
        ValueError: If the client raises any other ``MCPError``.
    """
    # Resolved at call time rather than at module import.
    # NOTE(review): presumably this avoids an import cycle and keeps
    # ``core.mcp.mcp_client.MCPClient`` patchable in tests - confirm.
    from core.mcp.mcp_client import MCPClient

    connection_kwargs = {
        "server_url": server_url,
        "headers": headers,
        "timeout": timeout,
        "sse_read_timeout": sse_read_timeout,
    }

    try:
        with MCPClient(**connection_kwargs) as client:
            remote_tools = client.list_tools()
            # Serialize while the client is still open, inside the try block,
            # so errors raised here are handled exactly like listing errors.
            serialized_tools = json.dumps([remote_tool.model_dump() for remote_tool in remote_tools])
            return ReconnectResult(
                authed=True,
                tools=serialized_tools,
                encrypted_credentials=EMPTY_CREDENTIALS_JSON,
            )
    except MCPAuthError:
        # The server was reached but wants authentication: report the provider
        # as not authed instead of failing the URL change.
        return ReconnectResult(
            authed=False,
            tools=EMPTY_TOOLS_JSON,
            encrypted_credentials=EMPTY_CREDENTIALS_JSON,
        )
    except MCPError as exc:
        raise ValueError(f"Failed to re-connect MCP server: {exc}") from exc
def _build_tool_provider_response( def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity: ) -> ToolProviderApiEntity:

View File

@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
] ]
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
# Setup mock client # Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db result = MCPToolManageService._reconnect_with_url(
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp", server_url="https://example.com/mcp",
provider=mcp_provider, headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
) )
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
assert tools_data[1]["name"] == "test_tool_2" assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions # Verify mock interactions
provider_entity = mcp_provider.to_entity() mock_mcp_client.assert_called_once_with(
mock_mcp_client.assert_called_once() server_url="https://example.com/mcp",
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
""" """
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
) )
# Mock MCPClient to raise authentication error # Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPAuthError from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test # Act: Execute the method under test
from extensions.ext_database import db result = MCPToolManageService._reconnect_with_url(
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp", server_url="https://example.com/mcp",
provider=mcp_provider, headers={},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
) )
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
) )
# Mock MCPClient to raise connection error # Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPError from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed") mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
service._reconnect_provider( MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp", server_url="https://example.com/mcp",
provider=mcp_provider, headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
) )