mirror of https://github.com/langgenius/dify.git
feat: implement RFC-compliant OAuth discovery with dynamic scope selection for MCP providers (#28294)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
014cbaf387
commit
6be013e072
|
|
@ -1086,7 +1086,13 @@ class ToolMCPAuthApi(Resource):
|
|||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
try:
|
||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||
# Pass the extracted OAuth metadata hints to auth()
|
||||
auth_result = auth(
|
||||
provider_entity,
|
||||
args.get("authorization_code"),
|
||||
resource_metadata_url=e.resource_metadata_url,
|
||||
scope_hint=e.scope_hint,
|
||||
)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
|
|
@ -1096,7 +1102,7 @@ class ToolMCPAuthApi(Resource):
|
|||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
except (MCPError, ValueError) as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ import secrets
|
|||
import urllib.parse
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||
import httpx
|
||||
from httpx import RequestError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
|
|
@ -20,6 +21,7 @@ from core.mcp.types import (
|
|||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
|
@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]:
|
|||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(
|
||||
www_auth_resource_metadata_url: str | None, server_url: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
"""
|
||||
urls = []
|
||||
|
||||
# First priority: URL from WWW-Authenticate header
|
||||
if www_auth_resource_metadata_url:
|
||||
urls.append(www_auth_resource_metadata_url)
|
||||
|
||||
# Fallback: construct from server URL
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
if fallback_url not in urls:
|
||||
urls.append(fallback_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
|
||||
|
||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||
|
||||
Per RFC 8414 section 3:
|
||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
||||
|
||||
Example:
|
||||
- issuer: https://example.com/oauth
|
||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
||||
"""
|
||||
urls = []
|
||||
base_url = auth_server_url or server_url
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
||||
|
||||
# Try OpenID Connect discovery first (more common)
|
||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
||||
# Include the path component if present in the issuer URL
|
||||
if path:
|
||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
||||
else:
|
||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def discover_protected_resource_metadata(
|
||||
prm_url: str | None, server_url: str, protocol_version: str | None = None
|
||||
) -> ProtectedResourceMetadata | None:
|
||||
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
|
||||
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return ProtectedResourceMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def discover_oauth_authorization_server_metadata(
|
||||
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
|
||||
) -> OAuthMetadata | None:
|
||||
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_effective_scope(
|
||||
scope_from_www_auth: str | None,
|
||||
prm: ProtectedResourceMetadata | None,
|
||||
asm: OAuthMetadata | None,
|
||||
client_scope: str | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Determine effective scope using priority-based selection strategy.
|
||||
|
||||
Priority order:
|
||||
1. WWW-Authenticate header scope (server explicit requirement)
|
||||
2. Protected Resource Metadata scopes
|
||||
3. OAuth Authorization Server Metadata scopes
|
||||
4. Client configured scope
|
||||
"""
|
||||
if scope_from_www_auth:
|
||||
return scope_from_www_auth
|
||||
|
||||
if prm and prm.scopes_supported:
|
||||
return " ".join(prm.scopes_supported)
|
||||
|
||||
if asm and asm.scopes_supported:
|
||||
return " ".join(asm.scopes_supported)
|
||||
|
||||
return client_scope
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
|
|
@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
|||
return False, ""
|
||||
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||
if support_resource_discovery:
|
||||
# The oauth_discovery_url is the authorization server base URL
|
||||
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||
urls_to_try = [
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
]
|
||||
else:
|
||||
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||
def discover_oauth_metadata(
|
||||
server_url: str,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
protocol_version: str | None = None,
|
||||
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
|
||||
"""
|
||||
Discover OAuth metadata using RFC 8414/9470 standards.
|
||||
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
Args:
|
||||
server_url: The MCP server URL
|
||||
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
|
||||
scope_hint: Scope hint from WWW-Authenticate header
|
||||
protocol_version: MCP protocol version
|
||||
|
||||
for url in urls_to_try:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
continue
|
||||
if not response.is_success:
|
||||
response.raise_for_status()
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except (RequestError, HTTPStatusError) as e:
|
||||
if isinstance(e, ConnectError):
|
||||
response = ssrf_proxy.get(url)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
# For other errors, try next URL
|
||||
continue
|
||||
Returns:
|
||||
(oauth_metadata, protected_resource_metadata, scope_hint)
|
||||
"""
|
||||
# Discover Protected Resource Metadata
|
||||
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
|
||||
|
||||
return None # No metadata found
|
||||
# Get authorization server URL from PRM or use server URL
|
||||
auth_server_url = None
|
||||
if prm and prm.authorization_servers:
|
||||
auth_server_url = prm.authorization_servers[0]
|
||||
|
||||
# Discover OAuth Authorization Server Metadata
|
||||
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
|
||||
|
||||
return asm, prm, scope_hint
|
||||
|
||||
|
||||
def start_authorization(
|
||||
|
|
@ -166,6 +287,7 @@ def start_authorization(
|
|||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
scope: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
|
|
@ -175,13 +297,6 @@ def start_authorization(
|
|||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
not metadata.code_challenge_methods_supported
|
||||
or code_challenge_method not in metadata.code_challenge_methods_supported
|
||||
):
|
||||
raise ValueError(
|
||||
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
|
|
@ -210,10 +325,49 @@ def start_authorization(
|
|||
"state": state_key,
|
||||
}
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
params["scope"] = scope
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
|
||||
"""
|
||||
Parse OAuth token response supporting both JSON and form-urlencoded formats.
|
||||
|
||||
Per RFC 6749 Section 5.1, the standard format is JSON.
|
||||
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
|
||||
application/x-www-form-urlencoded format for backwards compatibility.
|
||||
|
||||
Args:
|
||||
response: The HTTP response from token endpoint
|
||||
|
||||
Returns:
|
||||
Parsed OAuth tokens
|
||||
|
||||
Raises:
|
||||
ValueError: If response cannot be parsed
|
||||
"""
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
|
||||
if "application/json" in content_type:
|
||||
# Standard OAuth 2.0 JSON response (RFC 6749)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
elif "application/x-www-form-urlencoded" in content_type:
|
||||
# Legacy form-urlencoded response (non-standard but used by some providers)
|
||||
token_data = dict(urllib.parse.parse_qsl(response.text))
|
||||
return OAuthTokens.model_validate(token_data)
|
||||
else:
|
||||
# No content-type or unknown - try JSON first, fallback to form-urlencoded
|
||||
try:
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
token_data = dict(urllib.parse.parse_qsl(response.text))
|
||||
return OAuthTokens.model_validate(token_data)
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
|
|
@ -246,7 +400,7 @@ def exchange_authorization(
|
|||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
|
|
@ -279,7 +433,7 @@ def refresh_authorization(
|
|||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
|
|
@ -322,7 +476,7 @@ def client_credentials_flow(
|
|||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def register_client(
|
||||
|
|
@ -352,6 +506,8 @@ def auth(
|
|||
provider: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
|
@ -363,18 +519,26 @@ def auth(
|
|||
provider: The MCP provider entity
|
||||
authorization_code: Optional authorization code from OAuth callback
|
||||
state_param: Optional state parameter from OAuth callback
|
||||
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
|
||||
scope_hint: Optional scope hint from WWW-Authenticate header
|
||||
|
||||
Returns:
|
||||
AuthResult containing actions to be performed and response data
|
||||
"""
|
||||
actions: list[AuthAction] = []
|
||||
server_url = provider.decrypt_server_url()
|
||||
server_metadata = discover_oauth_metadata(server_url)
|
||||
|
||||
# Discover OAuth metadata using RFC 8414/9470 standards
|
||||
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
|
||||
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
|
||||
)
|
||||
|
||||
client_metadata = provider.client_metadata
|
||||
provider_id = provider.id
|
||||
tenant_id = provider.tenant_id
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
credentials = provider.decrypt_credentials()
|
||||
|
||||
# Determine grant type based on server metadata
|
||||
if not server_metadata:
|
||||
|
|
@ -392,8 +556,8 @@ def auth(
|
|||
else:
|
||||
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
# Get stored credentials
|
||||
credentials = provider.decrypt_credentials()
|
||||
# Determine effective scope using priority-based strategy
|
||||
effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
|
||||
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
|
|
@ -425,12 +589,11 @@ def auth(
|
|||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
scope = credentials.get("scope")
|
||||
tokens = client_credentials_flow(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
scope,
|
||||
effective_scope,
|
||||
)
|
||||
|
||||
# Return action to save tokens and grant type
|
||||
|
|
@ -526,6 +689,7 @@ def auth(
|
|||
redirect_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
effective_scope,
|
||||
)
|
||||
|
||||
# Return action to save code verifier
|
||||
|
|
|
|||
|
|
@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
# Perform authentication using the service's auth method
|
||||
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||
# Extract OAuth metadata hints from the error
|
||||
mcp_service.auth_with_actions(
|
||||
self.provider_entity,
|
||||
self.authorization_code,
|
||||
resource_metadata_url=error.resource_metadata_url,
|
||||
scope_hint=error.scope_hint,
|
||||
)
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = mcp_service.get_provider_entity(
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@ def sse_client(
|
|||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPAuthError(response=exc.response)
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -7,7 +14,49 @@ class MCPConnectionError(MCPError):
|
|||
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
response: "httpx.Response | None" = None,
|
||||
www_authenticate_header: str | None = None,
|
||||
):
|
||||
"""
|
||||
MCP Authentication Error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
response: HTTP response object (will extract WWW-Authenticate header if provided)
|
||||
www_authenticate_header: Pre-extracted WWW-Authenticate header value
|
||||
"""
|
||||
super().__init__(message or "Authentication failed")
|
||||
|
||||
# Extract OAuth metadata hints from WWW-Authenticate header
|
||||
if response is not None:
|
||||
www_authenticate_header = response.headers.get("WWW-Authenticate")
|
||||
|
||||
self.resource_metadata_url: str | None = None
|
||||
self.scope_hint: str | None = None
|
||||
|
||||
if www_authenticate_header:
|
||||
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
|
||||
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
|
||||
|
||||
@staticmethod
|
||||
def _extract_field(www_auth: str, field_name: str) -> str | None:
|
||||
"""Extract a specific field from the WWW-Authenticate header."""
|
||||
# Pattern to match field="value" or field=value
|
||||
pattern = rf'{field_name}="([^"]*)"'
|
||||
match = re.search(pattern, www_auth)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try without quotes
|
||||
pattern = rf"{field_name}=([^\s,]+)"
|
||||
match = re.search(pattern, www_auth)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MCPRefreshTokenError(MCPError):
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class BaseSession(
|
|||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_receive_request_type: type[ReceiveRequestT]
|
||||
|
|
@ -230,7 +230,7 @@ class BaseSession(
|
|||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
|
||||
self._response_streams[request_id] = response_queue
|
||||
|
||||
try:
|
||||
|
|
@ -261,11 +261,17 @@ class BaseSession(
|
|||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, HTTPStatusError):
|
||||
# HTTPStatusError from streamable_client with preserved response object
|
||||
if response_or_error.response.status_code == 401:
|
||||
raise MCPAuthError(response=response_or_error.response)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
raise MCPAuthError(message=response_or_error.error.message)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
|
|
@ -327,13 +333,17 @@ class BaseSession(
|
|||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
# For 401 errors, pass the HTTPStatusError directly to preserve response object
|
||||
if message.response.status_code == 401:
|
||||
response_queue.put(message)
|
||||
else:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ for reference.
|
|||
not separate types in the schema.
|
||||
"""
|
||||
# Client support both version, not support 2025-06-18 yet.
|
||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||
LATEST_PROTOCOL_VERSION = "2025-06-18"
|
||||
# Server support 2024-11-05 to allow claude to use.
|
||||
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
|
||||
|
|
@ -1330,3 +1330,13 @@ class OAuthMetadata(BaseModel):
|
|||
response_types_supported: list[str]
|
||||
grant_types_supported: list[str] | None = None
|
||||
code_challenge_methods_supported: list[str] | None = None
|
||||
scopes_supported: list[str] | None = None
|
||||
|
||||
|
||||
class ProtectedResourceMetadata(BaseModel):
|
||||
"""OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
|
||||
|
||||
resource: str | None = None
|
||||
authorization_servers: list[str]
|
||||
scopes_supported: list[str] | None = None
|
||||
bearer_methods_supported: list[str] | None = None
|
||||
|
|
|
|||
|
|
@ -507,7 +507,11 @@ class MCPToolManageService:
|
|||
return auth_result.response
|
||||
|
||||
def auth_with_actions(
|
||||
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
|
||||
self,
|
||||
provider_entity: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Perform authentication and execute all resulting actions.
|
||||
|
|
@ -517,11 +521,18 @@ class MCPToolManageService:
|
|||
Args:
|
||||
provider_entity: The MCP provider entity
|
||||
authorization_code: Optional authorization code
|
||||
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
|
||||
scope_hint: Optional scope hint from WWW-Authenticate header
|
||||
|
||||
Returns:
|
||||
Response dictionary from auth result
|
||||
"""
|
||||
auth_result = auth(provider_entity, authorization_code)
|
||||
auth_result = auth(
|
||||
provider_entity,
|
||||
authorization_code,
|
||||
resource_metadata_url=resource_metadata_url,
|
||||
scope_hint=scope_hint,
|
||||
)
|
||||
return self.execute_auth_actions(auth_result)
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ from core.mcp.auth.auth_flow import (
|
|||
)
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -154,7 +156,7 @@ class TestOAuthDiscovery:
|
|||
assert auth_url == "https://auth.example.com"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-protected-resource",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
|
||||
)
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
|
|
@ -183,59 +185,61 @@ class TestOAuthDiscovery:
|
|||
assert auth_url == "https://auth.example.com"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
|
||||
)
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
|
||||
def test_discover_oauth_metadata_with_resource_discovery(self):
|
||||
"""Test OAuth metadata discovery with resource discovery support."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
mock_check.return_value = (True, "https://auth.example.com")
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
# Mock protected resource metadata with auth server URL
|
||||
mock_prm.return_value = ProtectedResourceMetadata(
|
||||
resource="https://api.example.com",
|
||||
authorization_servers=["https://auth.example.com"],
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.is_success = True
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
# Mock OAuth authorization server metadata
|
||||
mock_asm.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
|
||||
metadata = discover_oauth_metadata("https://api.example.com")
|
||||
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
|
||||
|
||||
assert metadata is not None
|
||||
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||
assert metadata.token_endpoint == "https://auth.example.com/token"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||
)
|
||||
assert oauth_metadata is not None
|
||||
assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||
assert oauth_metadata.token_endpoint == "https://auth.example.com/token"
|
||||
assert prm is not None
|
||||
assert prm.authorization_servers == ["https://auth.example.com"]
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
|
||||
# Verify the discovery functions were called
|
||||
mock_prm.assert_called_once()
|
||||
mock_asm.assert_called_once()
|
||||
|
||||
def test_discover_oauth_metadata_without_resource_discovery(self):
|
||||
"""Test OAuth metadata discovery without resource discovery."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
mock_check.return_value = (False, "")
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
# Mock no protected resource metadata
|
||||
mock_prm.return_value = None
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.is_success = True
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||
"token_endpoint": "https://api.example.com/oauth/token",
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
# Mock OAuth authorization server metadata
|
||||
mock_asm.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://api.example.com/oauth/authorize",
|
||||
token_endpoint="https://api.example.com/oauth/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
|
||||
metadata = discover_oauth_metadata("https://api.example.com")
|
||||
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
|
||||
|
||||
assert metadata is not None
|
||||
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-authorization-server",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||
)
|
||||
assert oauth_metadata is not None
|
||||
assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
|
||||
assert prm is None
|
||||
|
||||
# Verify the discovery functions were called
|
||||
mock_prm.assert_called_once()
|
||||
mock_asm.assert_called_once()
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_not_found(self, mock_get):
|
||||
|
|
@ -247,9 +251,9 @@ class TestOAuthDiscovery:
|
|||
mock_response.status_code = 404
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
metadata = discover_oauth_metadata("https://api.example.com")
|
||||
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
|
||||
|
||||
assert metadata is None
|
||||
assert oauth_metadata is None
|
||||
|
||||
|
||||
class TestAuthorizationFlow:
|
||||
|
|
@ -342,6 +346,7 @@ class TestAuthorizationFlow:
|
|||
"""Test successful authorization code exchange."""
|
||||
mock_response = Mock()
|
||||
mock_response.is_success = True
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"token_type": "Bearer",
|
||||
|
|
@ -412,6 +417,7 @@ class TestAuthorizationFlow:
|
|||
"""Test successful token refresh."""
|
||||
mock_response = Mock()
|
||||
mock_response.is_success = True
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "refreshed-access-token",
|
||||
"token_type": "Bearer",
|
||||
|
|
@ -577,11 +583,15 @@ class TestAuthOrchestration:
|
|||
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for new client registration."""
|
||||
# Setup
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
mock_register.return_value = OAuthClientInformationFull(
|
||||
client_id="new-client-id",
|
||||
|
|
@ -619,11 +629,15 @@ class TestAuthOrchestration:
|
|||
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for exchanging authorization code."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
# Setup existing client
|
||||
|
|
@ -662,11 +676,15 @@ class TestAuthOrchestration:
|
|||
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow fails when exchanging code without state."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||
|
|
@ -698,11 +716,15 @@ class TestAuthOrchestration:
|
|||
mock_refresh.return_value = new_tokens
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
result = auth(mock_provider)
|
||||
|
|
@ -725,11 +747,15 @@ class TestAuthOrchestration:
|
|||
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth fails when no client info exists but code is provided."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_provider.retrieve_client_information.return_value = None
|
||||
|
|
|
|||
|
|
@ -139,7 +139,9 @@ def test_sse_client_error_handling():
|
|||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock 401 HTTP error
|
||||
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401))
|
||||
mock_response = Mock(status_code=401)
|
||||
mock_response.headers = {"WWW-Authenticate": 'Bearer realm="example"'}
|
||||
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=mock_response)
|
||||
mock_sse_connect.side_effect = mock_error
|
||||
|
||||
with pytest.raises(MCPAuthError):
|
||||
|
|
@ -150,7 +152,9 @@ def test_sse_client_error_handling():
|
|||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock other HTTP error
|
||||
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500))
|
||||
mock_response = Mock(status_code=500)
|
||||
mock_response.headers = {}
|
||||
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=mock_response)
|
||||
mock_sse_connect.side_effect = mock_error
|
||||
|
||||
with pytest.raises(MCPConnectionError):
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class TestConstants:
|
|||
|
||||
def test_protocol_versions(self):
|
||||
"""Test protocol version constants."""
|
||||
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
|
||||
assert LATEST_PROTOCOL_VERSION == "2025-06-18"
|
||||
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
|
||||
|
||||
def test_error_codes(self):
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ import { shouldUseMcpIconForAppIcon } from '@/utils/mcp'
|
|||
import TabSlider from '@/app/components/base/tab-slider'
|
||||
import { MCPAuthMethod } from '@/app/components/tools/types'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import AlertTriangle from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback/AlertTriangle'
|
||||
import { API_PREFIX } from '@/config'
|
||||
|
||||
export type DuplicateAppModalProps = {
|
||||
data?: ToolWithProvider
|
||||
|
|
@ -313,6 +315,17 @@ const MCPModal = ({
|
|||
/>
|
||||
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.useDynamicClientRegistration')}</span>
|
||||
</div>
|
||||
{!isDynamicRegistration && (
|
||||
<div className='mt-2 flex gap-2 rounded-lg bg-state-warning-hover p-3'>
|
||||
<AlertTriangle className='mt-0.5 h-4 w-4 shrink-0 text-text-warning' />
|
||||
<div className='system-xs-regular text-text-secondary'>
|
||||
<div className='mb-1'>{t('tools.mcp.modal.redirectUrlWarning')}</div>
|
||||
<code className='system-xs-medium block break-all rounded bg-state-warning-active px-2 py-1 text-text-secondary'>
|
||||
{`${API_PREFIX}/mcp/oauth/callback`}
|
||||
</code>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<div className={cn('mb-1 flex h-6 items-center', isDynamicRegistration && 'opacity-50')}>
|
||||
|
|
|
|||
|
|
@ -201,6 +201,7 @@ const translation = {
|
|||
timeoutPlaceholder: '30',
|
||||
authentication: 'Authentication',
|
||||
useDynamicClientRegistration: 'Use Dynamic Client Registration',
|
||||
redirectUrlWarning: 'Please configure your OAuth redirect URL to:',
|
||||
clientID: 'Client ID',
|
||||
clientSecret: 'Client Secret',
|
||||
clientSecretPlaceholder: 'Client Secret',
|
||||
|
|
|
|||
|
|
@ -201,6 +201,7 @@ const translation = {
|
|||
timeoutPlaceholder: '30',
|
||||
authentication: '认证',
|
||||
useDynamicClientRegistration: '使用动态客户端注册',
|
||||
redirectUrlWarning: '请将您的 OAuth 重定向 URL 配置为:',
|
||||
clientID: '客户端 ID',
|
||||
clientSecret: '客户端密钥',
|
||||
clientSecretPlaceholder: '客户端密钥',
|
||||
|
|
|
|||
Loading…
Reference in New Issue