From 8925606f33d7e4f329a315592da74e2ed3de1767 Mon Sep 17 00:00:00 2001 From: Novice Date: Fri, 5 Sep 2025 11:11:28 +0800 Subject: [PATCH] fix(mcp): prevent XSS attacks by validating OAuth endpoint URLs --- api/core/mcp/auth/auth_flow.py | 83 ++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 5626849edf..be4c40e908 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -25,6 +25,48 @@ OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" +def _validate_url_security(url: str) -> None: + """Validate URL to prevent XSS attacks by ensuring only safe protocols are allowed.""" + if not url: + raise ValueError("URL cannot be empty") + + try: + parsed_url = urlparse(url) + except Exception as e: + raise ValueError(f"Invalid URL format: {e}") + + # Only allow http and https protocols + allowed_schemes = ["http", "https"] + if parsed_url.scheme.lower() not in allowed_schemes: + raise ValueError(f"Unsafe URL protocol '{parsed_url.scheme}'. Only {allowed_schemes} are allowed") + + # Ensure the URL has a valid netloc (domain) + if not parsed_url.netloc: + raise ValueError("URL must have a valid domain") + + # Additional check for suspicious patterns that could indicate XSS attempts + url_lower = url.lower() + dangerous_patterns = ["javascript:", "data:", "vbscript:", "file:", "ftp:"] + for pattern in dangerous_patterns: + if pattern in url_lower: + raise ValueError(f"URL contains dangerous pattern: {pattern}") + + +def _validate_oauth_metadata_urls(metadata: OAuthMetadata) -> None: + """Validate all URLs in OAuth metadata to prevent XSS attacks.""" + # Validate authorization endpoint + if metadata.authorization_endpoint: + _validate_url_security(metadata.authorization_endpoint) + + # Validate token endpoint + if metadata.token_endpoint: + _validate_url_security(metadata.token_endpoint) + + # Validate registration endpoint + if metadata.registration_endpoint: + _validate_url_security(metadata.registration_endpoint) + + class OAuthCallbackState(BaseModel): provider_id: str tenant_id: str @@ -113,7 +155,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: if 200 <= response.status_code < 300: body = response.json() if "authorization_server_url" in body: - return True, body["authorization_server_url"][0] + auth_server_url = body["authorization_server_url"][0] + # Validate the authorization server URL to prevent XSS attacks + _validate_url_security(auth_server_url) + return True, auth_server_url else: return False, "" return False, "" @@ -124,10 +169,15 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" + # Validate the server URL first + _validate_url_security(server_url) + # First check if the server supports OAuth 2.0 Resource Discovery support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) if support_resource_discovery: url = oauth_discovery_url + # Validate the discovered OAuth URL + _validate_url_security(url) else: url = urljoin(server_url, "/.well-known/oauth-authorization-server") @@ -138,7 +188,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N return None if not response.is_success: raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") - return OAuthMetadata.model_validate(response.json()) + + metadata = OAuthMetadata.model_validate(response.json()) + # Validate all URLs in the metadata to prevent XSS attacks + _validate_oauth_metadata_urls(metadata) + return metadata except httpx.RequestError as e: if isinstance(e, httpx.ConnectError): response = httpx.get(url) @@ -146,7 +200,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N return None if not response.is_success: raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") - return OAuthMetadata.model_validate(response.json()) + + metadata = OAuthMetadata.model_validate(response.json()) + # Validate all URLs in the metadata to prevent XSS attacks + _validate_oauth_metadata_urls(metadata) + return metadata raise @@ -164,6 +222,9 @@ def start_authorization( if metadata: authorization_url = metadata.authorization_endpoint + # Validate the authorization endpoint URL to prevent XSS attacks + _validate_url_security(authorization_url) + if response_type not in metadata.response_types_supported: raise ValueError(f"Incompatible auth server: does not support response type {response_type}") if ( @@ -175,6 +236,8 @@ def start_authorization( ) else: authorization_url = urljoin(server_url, "/authorize") + # Validate the constructed authorization URL + _validate_url_security(authorization_url) code_verifier, code_challenge = generate_pkce_challenge() @@ -218,10 +281,15 @@ def exchange_authorization( if metadata: token_url = metadata.token_endpoint + # Validate the token endpoint URL to prevent XSS attacks + _validate_url_security(token_url) + if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") else: token_url = urljoin(server_url, "/token") + # Validate the constructed token URL + _validate_url_security(token_url) params = { "grant_type": grant_type, @@ -251,10 +319,15 @@ def refresh_authorization( if metadata: token_url = metadata.token_endpoint + # Validate the token endpoint URL to prevent XSS attacks + _validate_url_security(token_url) + if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") else: token_url = urljoin(server_url, "/token") + # Validate the constructed token URL + _validate_url_security(token_url) params = { "grant_type": grant_type, @@ -281,8 +354,12 @@ def register_client( if not metadata.registration_endpoint: raise ValueError("Incompatible auth server: does not support dynamic client registration") registration_url = metadata.registration_endpoint + # Validate the registration endpoint URL to prevent XSS attacks + _validate_url_security(registration_url) else: registration_url = urljoin(server_url, "/register") + # Validate the constructed registration URL + _validate_url_security(registration_url) response = httpx.post( registration_url,