feat: implement RFC-compliant OAuth discovery with dynamic scope selection for MCP providers (#28294)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Novice 2025-11-20 11:18:16 +08:00 committed by GitHub
parent 014cbaf387
commit 6be013e072
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 442 additions and 141 deletions

View File

@ -1086,7 +1086,13 @@ class ToolMCPAuthApi(Resource):
return {"result": "success"}
except MCPAuthError as e:
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
# Pass the extracted OAuth metadata hints to auth()
auth_result = auth(
provider_entity,
args.get("authorization_code"),
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
@ -1096,7 +1102,7 @@ class ToolMCPAuthApi(Resource):
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)

View File

@ -6,7 +6,8 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
from httpx import ConnectError, HTTPStatusError, RequestError
import httpx
from httpx import RequestError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
@ -20,6 +21,7 @@ from core.mcp.types import (
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
from extensions.ext_redis import redis_client
@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]:
return code_verifier, code_challenge
def build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url: str | None, server_url: str
) -> list[str]:
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
"""
urls = []
# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
return urls
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
return urls
def discover_protected_resource_metadata(
prm_url: str | None, server_url: str, protocol_version: str | None = None
) -> ProtectedResourceMetadata | None:
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def discover_oauth_authorization_server_metadata(
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
) -> OAuthMetadata | None:
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def get_effective_scope(
scope_from_www_auth: str | None,
prm: ProtectedResourceMetadata | None,
asm: OAuthMetadata | None,
client_scope: str | None,
) -> str | None:
"""
Determine effective scope using priority-based selection strategy.
Priority order:
1. WWW-Authenticate header scope (server explicit requirement)
2. Protected Resource Metadata scopes
3. OAuth Authorization Server Metadata scopes
4. Client configured scope
"""
if scope_from_www_auth:
return scope_from_www_auth
if prm and prm.scopes_supported:
return " ".join(prm.scopes_supported)
if asm and asm.scopes_supported:
return " ".join(asm.scopes_supported)
return client_scope
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
return False, ""
def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
def discover_oauth_metadata(
server_url: str,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
protocol_version: str | None = None,
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
"""
Discover OAuth metadata using RFC 8414/9470 standards.
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
Args:
server_url: The MCP server URL
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
scope_hint: Scope hint from WWW-Authenticate header
protocol_version: MCP protocol version
for url in urls_to_try:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
continue
if not response.is_success:
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
Returns:
(oauth_metadata, protected_resource_metadata, scope_hint)
"""
# Discover Protected Resource Metadata
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
return None # No metadata found
# Get authorization server URL from PRM or use server URL
auth_server_url = None
if prm and prm.authorization_servers:
auth_server_url = prm.authorization_servers[0]
# Discover OAuth Authorization Server Metadata
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
return asm, prm, scope_hint
def start_authorization(
@ -166,6 +287,7 @@ def start_authorization(
redirect_url: str,
provider_id: str,
tenant_id: str,
scope: str | None = None,
) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
@ -175,13 +297,6 @@ def start_authorization(
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else:
authorization_url = urljoin(server_url, "/authorize")
@ -210,10 +325,49 @@ def start_authorization(
"state": state_key,
}
# Add scope if provided
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
"""
Parse OAuth token response supporting both JSON and form-urlencoded formats.
Per RFC 6749 Section 5.1, the standard format is JSON.
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
application/x-www-form-urlencoded format for backwards compatibility.
Args:
response: The HTTP response from token endpoint
Returns:
Parsed OAuth tokens
Raises:
ValueError: If response cannot be parsed
"""
content_type = response.headers.get("content-type", "").lower()
if "application/json" in content_type:
# Standard OAuth 2.0 JSON response (RFC 6749)
return OAuthTokens.model_validate(response.json())
elif "application/x-www-form-urlencoded" in content_type:
# Legacy form-urlencoded response (non-standard but used by some providers)
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
else:
# No content-type or unknown - try JSON first, fallback to form-urlencoded
try:
return OAuthTokens.model_validate(response.json())
except (ValidationError, json.JSONDecodeError):
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
def exchange_authorization(
server_url: str,
metadata: OAuthMetadata | None,
@ -246,7 +400,7 @@ def exchange_authorization(
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
return _parse_token_response(response)
def refresh_authorization(
@ -279,7 +433,7 @@ def refresh_authorization(
raise MCPRefreshTokenError(e) from e
if not response.is_success:
raise MCPRefreshTokenError(response.text)
return OAuthTokens.model_validate(response.json())
return _parse_token_response(response)
def client_credentials_flow(
@ -322,7 +476,7 @@ def client_credentials_flow(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
return _parse_token_response(response)
def register_client(
@ -352,6 +506,8 @@ def auth(
provider: MCPProviderEntity,
authorization_code: str | None = None,
state_param: str | None = None,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
@ -363,18 +519,26 @@ def auth(
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
scope_hint: Optional scope hint from WWW-Authenticate header
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
# Discover OAuth metadata using RFC 8414/9470 standards
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
credentials = provider.decrypt_credentials()
# Determine grant type based on server metadata
if not server_metadata:
@ -392,8 +556,8 @@ def auth(
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
# Determine effective scope using priority-based strategy
effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
if not client_information:
if authorization_code is not None:
@ -425,12 +589,11 @@ def auth(
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
effective_scope,
)
# Return action to save tokens and grant type
@ -526,6 +689,7 @@ def auth(
redirect_url,
provider_id,
tenant_id,
effective_scope,
)
# Return action to save code verifier

View File

@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient):
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Extract OAuth metadata hints from the error
mcp_service.auth_with_actions(
self.provider_entity,
self.authorization_code,
resource_metadata_url=error.resource_metadata_url,
scope_hint=error.scope_hint,
)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(

View File

@ -290,7 +290,7 @@ def sse_client(
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPAuthError(response=exc.response)
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")

View File

@ -1,3 +1,10 @@
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
class MCPError(Exception):
pass
@ -7,7 +14,49 @@ class MCPConnectionError(MCPError):
class MCPAuthError(MCPConnectionError):
pass
def __init__(
self,
message: str | None = None,
response: "httpx.Response | None" = None,
www_authenticate_header: str | None = None,
):
"""
MCP Authentication Error.
Args:
message: Error message
response: HTTP response object (will extract WWW-Authenticate header if provided)
www_authenticate_header: Pre-extracted WWW-Authenticate header value
"""
super().__init__(message or "Authentication failed")
# Extract OAuth metadata hints from WWW-Authenticate header
if response is not None:
www_authenticate_header = response.headers.get("WWW-Authenticate")
self.resource_metadata_url: str | None = None
self.scope_hint: str | None = None
if www_authenticate_header:
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
@staticmethod
def _extract_field(www_auth: str, field_name: str) -> str | None:
"""Extract a specific field from the WWW-Authenticate header."""
# Pattern to match field="value" or field=value
pattern = rf'{field_name}="([^"]*)"'
match = re.search(pattern, www_auth)
if match:
return match.group(1)
# Try without quotes
pattern = rf"{field_name}=([^\s,]+)"
match = re.search(pattern, www_auth)
if match:
return match.group(1)
return None
class MCPRefreshTokenError(MCPError):

View File

@ -149,7 +149,7 @@ class BaseSession(
messages when entered.
"""
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT]
@ -230,7 +230,7 @@ class BaseSession(
request_id = self._request_id
self._request_id = request_id + 1
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
self._response_streams[request_id] = response_queue
try:
@ -261,11 +261,17 @@ class BaseSession(
message="No response received",
)
)
elif isinstance(response_or_error, HTTPStatusError):
# HTTPStatusError from streamable_client with preserved response object
if response_or_error.response.status_code == 401:
raise MCPAuthError(response=response_or_error.response)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
)
elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401:
raise MCPAuthError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
raise MCPAuthError(message=response_or_error.error.message)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
@ -327,13 +333,17 @@ class BaseSession(
if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
# For 401 errors, pass the HTTPStatusError directly to preserve response object
if message.response.status_code == 401:
response_queue.put(message)
else:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
)
)
)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
elif isinstance(message, Exception):

View File

@ -23,7 +23,7 @@ for reference.
not separate types in the schema.
"""
# Client support both version, not support 2025-06-18 yet.
LATEST_PROTOCOL_VERSION = "2025-03-26"
LATEST_PROTOCOL_VERSION = "2025-06-18"
# Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
@ -1330,3 +1330,13 @@ class OAuthMetadata(BaseModel):
response_types_supported: list[str]
grant_types_supported: list[str] | None = None
code_challenge_methods_supported: list[str] | None = None
scopes_supported: list[str] | None = None
class ProtectedResourceMetadata(BaseModel):
"""OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
resource: str | None = None
authorization_servers: list[str]
scopes_supported: list[str] | None = None
bearer_methods_supported: list[str] | None = None

View File

@ -507,7 +507,11 @@ class MCPToolManageService:
return auth_result.response
def auth_with_actions(
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
self,
provider_entity: MCPProviderEntity,
authorization_code: str | None = None,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
) -> dict[str, str]:
"""
Perform authentication and execute all resulting actions.
@ -517,11 +521,18 @@ class MCPToolManageService:
Args:
provider_entity: The MCP provider entity
authorization_code: Optional authorization code
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
scope_hint: Optional scope hint from WWW-Authenticate header
Returns:
Response dictionary from auth result
"""
auth_result = auth(provider_entity, authorization_code)
auth_result = auth(
provider_entity,
authorization_code,
resource_metadata_url=resource_metadata_url,
scope_hint=scope_hint,
)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:

View File

@ -23,11 +23,13 @@ from core.mcp.auth.auth_flow import (
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
@ -154,7 +156,7 @@ class TestOAuthDiscovery:
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
@ -183,59 +185,61 @@ class TestOAuthDiscovery:
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
def test_discover_oauth_metadata_with_resource_discovery(self):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (True, "https://auth.example.com")
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
# Mock protected resource metadata with auth server URL
mock_prm.return_value = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=["https://auth.example.com"],
)
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
# Mock OAuth authorization server metadata
mock_asm.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
)
metadata = discover_oauth_metadata("https://api.example.com")
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
assert oauth_metadata is not None
assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert oauth_metadata.token_endpoint == "https://auth.example.com/token"
assert prm is not None
assert prm.authorization_servers == ["https://auth.example.com"]
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
# Verify the discovery functions were called
mock_prm.assert_called_once()
mock_asm.assert_called_once()
def test_discover_oauth_metadata_without_resource_discovery(self):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
# Mock no protected resource metadata
mock_prm.return_value = None
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
# Mock OAuth authorization server metadata
mock_asm.return_value = OAuthMetadata(
authorization_endpoint="https://api.example.com/oauth/authorize",
token_endpoint="https://api.example.com/oauth/token",
response_types_supported=["code"],
)
metadata = discover_oauth_metadata("https://api.example.com")
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
assert oauth_metadata is not None
assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
assert prm is None
# Verify the discovery functions were called
mock_prm.assert_called_once()
mock_asm.assert_called_once()
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
@ -247,9 +251,9 @@ class TestOAuthDiscovery:
mock_response.status_code = 404
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
assert metadata is None
assert oauth_metadata is None
class TestAuthorizationFlow:
@ -342,6 +346,7 @@ class TestAuthorizationFlow:
"""Test successful authorization code exchange."""
mock_response = Mock()
mock_response.is_success = True
mock_response.headers = {"content-type": "application/json"}
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
@ -412,6 +417,7 @@ class TestAuthorizationFlow:
"""Test successful token refresh."""
mock_response = Mock()
mock_response.is_success = True
mock_response.headers = {"content-type": "application/json"}
mock_response.json.return_value = {
"access_token": "refreshed-access-token",
"token_type": "Bearer",
@ -577,11 +583,15 @@ class TestAuthOrchestration:
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
@ -619,11 +629,15 @@ class TestAuthOrchestration:
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
"""Test auth flow for exchanging authorization code."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
# Setup existing client
@ -662,11 +676,15 @@ class TestAuthOrchestration:
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
"""Test auth flow fails when exchanging code without state."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
@ -698,11 +716,15 @@ class TestAuthOrchestration:
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
result = auth(mock_provider)
@ -725,11 +747,15 @@ class TestAuthOrchestration:
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
"""Test auth fails when no client info exists but code is provided."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_provider.retrieve_client_information.return_value = None

View File

@ -139,7 +139,9 @@ def test_sse_client_error_handling():
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
# Mock 401 HTTP error
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401))
mock_response = Mock(status_code=401)
mock_response.headers = {"WWW-Authenticate": 'Bearer realm="example"'}
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=mock_response)
mock_sse_connect.side_effect = mock_error
with pytest.raises(MCPAuthError):
@ -150,7 +152,9 @@ def test_sse_client_error_handling():
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
# Mock other HTTP error
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500))
mock_response = Mock(status_code=500)
mock_response.headers = {}
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=mock_response)
mock_sse_connect.side_effect = mock_error
with pytest.raises(MCPConnectionError):

View File

@ -58,7 +58,7 @@ class TestConstants:
def test_protocol_versions(self):
"""Test protocol version constants."""
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
assert LATEST_PROTOCOL_VERSION == "2025-06-18"
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
def test_error_codes(self):

View File

@ -24,6 +24,8 @@ import { shouldUseMcpIconForAppIcon } from '@/utils/mcp'
import TabSlider from '@/app/components/base/tab-slider'
import { MCPAuthMethod } from '@/app/components/tools/types'
import Switch from '@/app/components/base/switch'
import AlertTriangle from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback/AlertTriangle'
import { API_PREFIX } from '@/config'
export type DuplicateAppModalProps = {
data?: ToolWithProvider
@ -313,6 +315,17 @@ const MCPModal = ({
/>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.useDynamicClientRegistration')}</span>
</div>
{!isDynamicRegistration && (
<div className='mt-2 flex gap-2 rounded-lg bg-state-warning-hover p-3'>
<AlertTriangle className='mt-0.5 h-4 w-4 shrink-0 text-text-warning' />
<div className='system-xs-regular text-text-secondary'>
<div className='mb-1'>{t('tools.mcp.modal.redirectUrlWarning')}</div>
<code className='system-xs-medium block break-all rounded bg-state-warning-active px-2 py-1 text-text-secondary'>
{`${API_PREFIX}/mcp/oauth/callback`}
</code>
</div>
</div>
)}
</div>
<div>
<div className={cn('mb-1 flex h-6 items-center', isDynamicRegistration && 'opacity-50')}>

View File

@ -201,6 +201,7 @@ const translation = {
timeoutPlaceholder: '30',
authentication: 'Authentication',
useDynamicClientRegistration: 'Use Dynamic Client Registration',
redirectUrlWarning: 'Please configure your OAuth redirect URL to:',
clientID: 'Client ID',
clientSecret: 'Client Secret',
clientSecretPlaceholder: 'Client Secret',

View File

@ -201,6 +201,7 @@ const translation = {
timeoutPlaceholder: '30',
authentication: '认证',
useDynamicClientRegistration: '使用动态客户端注册',
redirectUrlWarning: '请将您的 OAuth 重定向 URL 配置为:',
clientID: '客户端 ID',
clientSecret: '客户端密钥',
clientSecretPlaceholder: '客户端密钥',