fix(mcp): prevent XSS attacks by validating OAuth endpoint URLs

This commit is contained in:
Novice 2025-09-05 11:11:28 +08:00
parent f0561c0c3b
commit 8925606f33
1 changed files with 80 additions and 3 deletions

View File

@ -25,6 +25,48 @@ OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
def _validate_url_security(url: str) -> None:
"""Validate URL to prevent XSS attacks by ensuring only safe protocols are allowed."""
if not url:
raise ValueError("URL cannot be empty")
try:
parsed_url = urlparse(url)
except Exception as e:
raise ValueError(f"Invalid URL format: {e}")
# Only allow http and https protocols
allowed_schemes = ["http", "https"]
if parsed_url.scheme.lower() not in allowed_schemes:
raise ValueError(f"Unsafe URL protocol '{parsed_url.scheme}'. Only {allowed_schemes} are allowed")
# Ensure the URL has a valid netloc (domain)
if not parsed_url.netloc:
raise ValueError("URL must have a valid domain")
# Additional check for suspicious patterns that could indicate XSS attempts
url_lower = url.lower()
dangerous_patterns = ["javascript:", "data:", "vbscript:", "file:", "ftp:"]
for pattern in dangerous_patterns:
if pattern in url_lower:
raise ValueError(f"URL contains dangerous pattern: {pattern}")
def _validate_oauth_metadata_urls(metadata: OAuthMetadata) -> None:
"""Validate all URLs in OAuth metadata to prevent XSS attacks."""
# Validate authorization endpoint
if metadata.authorization_endpoint:
_validate_url_security(metadata.authorization_endpoint)
# Validate token endpoint
if metadata.token_endpoint:
_validate_url_security(metadata.token_endpoint)
# Validate registration endpoint
if metadata.registration_endpoint:
_validate_url_security(metadata.registration_endpoint)
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
@ -113,7 +155,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
if 200 <= response.status_code < 300:
body = response.json()
if "authorization_server_url" in body:
return True, body["authorization_server_url"][0]
auth_server_url = body["authorization_server_url"][0]
# Validate the authorization server URL to prevent XSS attacks
_validate_url_security(auth_server_url)
return True, auth_server_url
else:
return False, ""
return False, ""
@ -124,10 +169,15 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
# Validate the server URL first
_validate_url_security(server_url)
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
url = oauth_discovery_url
# Validate the discovered OAuth URL
_validate_url_security(url)
else:
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
@ -138,7 +188,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
metadata = OAuthMetadata.model_validate(response.json())
# Validate all URLs in the metadata to prevent XSS attacks
_validate_oauth_metadata_urls(metadata)
return metadata
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
@ -146,7 +200,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
metadata = OAuthMetadata.model_validate(response.json())
# Validate all URLs in the metadata to prevent XSS attacks
_validate_oauth_metadata_urls(metadata)
return metadata
raise
@ -164,6 +222,9 @@ def start_authorization(
if metadata:
authorization_url = metadata.authorization_endpoint
# Validate the authorization endpoint URL to prevent XSS attacks
_validate_url_security(authorization_url)
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
@ -175,6 +236,8 @@ def start_authorization(
)
else:
authorization_url = urljoin(server_url, "/authorize")
# Validate the constructed authorization URL
_validate_url_security(authorization_url)
code_verifier, code_challenge = generate_pkce_challenge()
@ -218,10 +281,15 @@ def exchange_authorization(
if metadata:
token_url = metadata.token_endpoint
# Validate the token endpoint URL to prevent XSS attacks
_validate_url_security(token_url)
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Validate the constructed token URL
_validate_url_security(token_url)
params = {
"grant_type": grant_type,
@ -251,10 +319,15 @@ def refresh_authorization(
if metadata:
token_url = metadata.token_endpoint
# Validate the token endpoint URL to prevent XSS attacks
_validate_url_security(token_url)
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Validate the constructed token URL
_validate_url_security(token_url)
params = {
"grant_type": grant_type,
@ -281,8 +354,12 @@ def register_client(
if not metadata.registration_endpoint:
raise ValueError("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint
# Validate the registration endpoint URL to prevent XSS attacks
_validate_url_security(registration_url)
else:
registration_url = urljoin(server_url, "/register")
# Validate the constructed registration URL
_validate_url_security(registration_url)
response = httpx.post(
registration_url,