diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 1c9d438ca6..cc7fa0fc3d 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1086,7 +1086,13 @@ class ToolMCPAuthApi(Resource): return {"result": "success"} except MCPAuthError as e: try: - auth_result = auth(provider_entity, args.get("authorization_code")) + # Pass the extracted OAuth metadata hints to auth() + auth_result = auth( + provider_entity, + args.get("authorization_code"), + resource_metadata_url=e.resource_metadata_url, + scope_hint=e.scope_hint, + ) with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) @@ -1096,7 +1102,7 @@ class ToolMCPAuthApi(Resource): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e - except MCPError as e: + except (MCPError, ValueError) as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 951c22f6dd..92787b39dd 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -6,7 +6,8 @@ import secrets import urllib.parse from urllib.parse import urljoin, urlparse -from httpx import ConnectError, HTTPStatusError, RequestError +import httpx +from httpx import RequestError from pydantic import ValidationError from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType @@ -20,6 +21,7 @@ from core.mcp.types import ( OAuthClientMetadata, OAuthMetadata, OAuthTokens, + ProtectedResourceMetadata, ) from extensions.ext_redis import redis_client @@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]: return code_verifier, code_challenge +def build_protected_resource_metadata_discovery_urls( + www_auth_resource_metadata_url: str | None, server_url: str +) -> list[str]: + """ + Build a list of URLs to try for Protected Resource Metadata discovery. + + Per SEP-985, supports fallback when discovery fails at one URL. + """ + urls = [] + + # First priority: URL from WWW-Authenticate header + if www_auth_resource_metadata_url: + urls.append(www_auth_resource_metadata_url) + + # Fallback: construct from server URL + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") + if fallback_url not in urls: + urls.append(fallback_url) + + return urls + + +def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: + """ + Build a list of URLs to try for OAuth Authorization Server Metadata discovery. + + Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. + + Per RFC 8414 section 3: + - If issuer has no path: https://example.com/.well-known/oauth-authorization-server + - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} + + Example: + - issuer: https://example.com/oauth + - metadata: https://example.com/.well-known/oauth-authorization-server/oauth + """ + urls = [] + base_url = auth_server_url or server_url + + parsed = urlparse(base_url) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") # Remove trailing slash + + # Try OpenID Connect discovery first (more common) + urls.append(urljoin(base + "/", ".well-known/openid-configuration")) + + # OAuth 2.0 Authorization Server Metadata (RFC 8414) + # Include the path component if present in the issuer URL + if path: + urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) + else: + urls.append(urljoin(base, ".well-known/oauth-authorization-server")) + + return urls + + +def discover_protected_resource_metadata( + prm_url: str | None, server_url: str, protocol_version: str | None = None +) -> ProtectedResourceMetadata | None: + """Discover OAuth 2.0 Protected Resource Metadata (RFC 9470).""" + urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url) + headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} + + for url in urls: + try: + response = ssrf_proxy.get(url, headers=headers) + if response.status_code == 200: + return ProtectedResourceMetadata.model_validate(response.json()) + elif response.status_code == 404: + continue # Try next URL + except (RequestError, ValidationError): + continue # Try next URL + + return None + + +def discover_oauth_authorization_server_metadata( + auth_server_url: str | None, server_url: str, protocol_version: str | None = None +) -> OAuthMetadata | None: + """Discover OAuth 2.0 Authorization Server Metadata (RFC 8414).""" + urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url) + headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} + + for url in urls: + try: + response = ssrf_proxy.get(url, headers=headers) + if response.status_code == 200: + return OAuthMetadata.model_validate(response.json()) + elif response.status_code == 404: + continue # Try next URL + except (RequestError, ValidationError): + continue # Try next URL + + return None + + +def get_effective_scope( + scope_from_www_auth: str | None, + prm: ProtectedResourceMetadata | None, + asm: OAuthMetadata | None, + client_scope: str | None, +) -> str | None: + """ + Determine effective scope using priority-based selection strategy. + + Priority order: + 1. WWW-Authenticate header scope (server explicit requirement) + 2. Protected Resource Metadata scopes + 3. OAuth Authorization Server Metadata scopes + 4. Client configured scope + """ + if scope_from_www_auth: + return scope_from_www_auth + + if prm and prm.scopes_supported: + return " ".join(prm.scopes_supported) + + if asm and asm.scopes_supported: + return " ".join(asm.scopes_supported) + + return client_scope + + def _create_secure_redis_state(state_data: OAuthCallbackState) -> str: """Create a secure state parameter by storing state data in Redis and returning a random state key.""" # Generate a secure random state key @@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: return False, "" -def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None: - """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" - # First check if the server supports OAuth 2.0 Resource Discovery - support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) - if support_resource_discovery: - # The oauth_discovery_url is the authorization server base URL - # Try OpenID Connect discovery first (more common), then OAuth 2.0 - urls_to_try = [ - urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"), - urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"), - ] - else: - urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")] +def discover_oauth_metadata( + server_url: str, + resource_metadata_url: str | None = None, + scope_hint: str | None = None, + protocol_version: str | None = None, +) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]: + """ + Discover OAuth metadata using RFC 8414/9470 standards. - headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} + Args: + server_url: The MCP server URL + resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header + scope_hint: Scope hint from WWW-Authenticate header + protocol_version: MCP protocol version - for url in urls_to_try: - try: - response = ssrf_proxy.get(url, headers=headers) - if response.status_code == 404: - continue - if not response.is_success: - response.raise_for_status() - return OAuthMetadata.model_validate(response.json()) - except (RequestError, HTTPStatusError) as e: - if isinstance(e, ConnectError): - response = ssrf_proxy.get(url) - if response.status_code == 404: - continue # Try next URL - if not response.is_success: - raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") - return OAuthMetadata.model_validate(response.json()) - # For other errors, try next URL - continue + Returns: + (oauth_metadata, protected_resource_metadata, scope_hint) + """ + # Discover Protected Resource Metadata + prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version) - return None # No metadata found + # Get authorization server URL from PRM or use server URL + auth_server_url = None + if prm and prm.authorization_servers: + auth_server_url = prm.authorization_servers[0] + + # Discover OAuth Authorization Server Metadata + asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version) + + return asm, prm, scope_hint def start_authorization( @@ -166,6 +287,7 @@ def start_authorization( redirect_url: str, provider_id: str, tenant_id: str, + scope: str | None = None, ) -> tuple[str, str]: """Begins the authorization flow with secure Redis state storage.""" response_type = "code" @@ -175,13 +297,6 @@ def start_authorization( authorization_url = metadata.authorization_endpoint if response_type not in metadata.response_types_supported: raise ValueError(f"Incompatible auth server: does not support response type {response_type}") - if ( - not metadata.code_challenge_methods_supported - or code_challenge_method not in metadata.code_challenge_methods_supported - ): - raise ValueError( - f"Incompatible auth server: does not support code challenge method {code_challenge_method}" - ) else: authorization_url = urljoin(server_url, "/authorize") @@ -210,10 +325,49 @@ def start_authorization( "state": state_key, } + # Add scope if provided + if scope: + params["scope"] = scope + authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" return authorization_url, code_verifier +def _parse_token_response(response: httpx.Response) -> OAuthTokens: + """ + Parse OAuth token response supporting both JSON and form-urlencoded formats. + + Per RFC 6749 Section 5.1, the standard format is JSON. + However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return + application/x-www-form-urlencoded format for backwards compatibility. + + Args: + response: The HTTP response from token endpoint + + Returns: + Parsed OAuth tokens + + Raises: + ValueError: If response cannot be parsed + """ + content_type = response.headers.get("content-type", "").lower() + + if "application/json" in content_type: + # Standard OAuth 2.0 JSON response (RFC 6749) + return OAuthTokens.model_validate(response.json()) + elif "application/x-www-form-urlencoded" in content_type: + # Legacy form-urlencoded response (non-standard but used by some providers) + token_data = dict(urllib.parse.parse_qsl(response.text)) + return OAuthTokens.model_validate(token_data) + else: + # No content-type or unknown - try JSON first, fallback to form-urlencoded + try: + return OAuthTokens.model_validate(response.json()) + except (ValidationError, json.JSONDecodeError): + token_data = dict(urllib.parse.parse_qsl(response.text)) + return OAuthTokens.model_validate(token_data) + + def exchange_authorization( server_url: str, metadata: OAuthMetadata | None, @@ -246,7 +400,7 @@ def exchange_authorization( response = ssrf_proxy.post(token_url, data=params) if not response.is_success: raise ValueError(f"Token exchange failed: HTTP {response.status_code}") - return OAuthTokens.model_validate(response.json()) + return _parse_token_response(response) def refresh_authorization( @@ -279,7 +433,7 @@ def refresh_authorization( raise MCPRefreshTokenError(e) from e if not response.is_success: raise MCPRefreshTokenError(response.text) - return OAuthTokens.model_validate(response.json()) + return _parse_token_response(response) def client_credentials_flow( @@ -322,7 +476,7 @@ def client_credentials_flow( f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}" ) - return OAuthTokens.model_validate(response.json()) + return _parse_token_response(response) def register_client( @@ -352,6 +506,8 @@ def auth( provider: MCPProviderEntity, authorization_code: str | None = None, state_param: str | None = None, + resource_metadata_url: str | None = None, + scope_hint: str | None = None, ) -> AuthResult: """ Orchestrates the full auth flow with a server using secure Redis state storage. @@ -363,18 +519,26 @@ def auth( provider: The MCP provider entity authorization_code: Optional authorization code from OAuth callback state_param: Optional state parameter from OAuth callback + resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate + scope_hint: Optional scope hint from WWW-Authenticate header Returns: AuthResult containing actions to be performed and response data """ actions: list[AuthAction] = [] server_url = provider.decrypt_server_url() - server_metadata = discover_oauth_metadata(server_url) + + # Discover OAuth metadata using RFC 8414/9470 standards + server_metadata, prm, scope_from_www_auth = discover_oauth_metadata( + server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION + ) + client_metadata = provider.client_metadata provider_id = provider.id tenant_id = provider.tenant_id client_information = provider.retrieve_client_information() redirect_url = provider.redirect_url + credentials = provider.decrypt_credentials() # Determine grant type based on server metadata if not server_metadata: @@ -392,8 +556,8 @@ def auth( else: effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value - # Get stored credentials - credentials = provider.decrypt_credentials() + # Determine effective scope using priority-based strategy + effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope")) if not client_information: if authorization_code is not None: @@ -425,12 +589,11 @@ def auth( if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value: # Direct token request without user interaction try: - scope = credentials.get("scope") tokens = client_credentials_flow( server_url, server_metadata, client_information, - scope, + effective_scope, ) # Return action to save tokens and grant type @@ -526,6 +689,7 @@ def auth( redirect_url, provider_id, tenant_id, + effective_scope, ) # Return action to save code verifier diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index 942c8d3c23..d8724b8de5 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient): mcp_service = MCPToolManageService(session=session) # Perform authentication using the service's auth method - mcp_service.auth_with_actions(self.provider_entity, self.authorization_code) + # Extract OAuth metadata hints from the error + mcp_service.auth_with_actions( + self.provider_entity, + self.authorization_code, + resource_metadata_url=error.resource_metadata_url, + scope_hint=error.scope_hint, + ) # Retrieve new tokens self.provider_entity = mcp_service.get_provider_entity( diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 2d5e3dd263..24ca59ee45 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -290,7 +290,7 @@ def sse_client( except httpx.HTTPStatusError as exc: if exc.response.status_code == 401: - raise MCPAuthError() + raise MCPAuthError(response=exc.response) raise MCPConnectionError() except Exception: logger.exception("Error connecting to SSE endpoint") diff --git a/api/core/mcp/error.py b/api/core/mcp/error.py index d4fb8b7674..1128369ac5 100644 --- a/api/core/mcp/error.py +++ b/api/core/mcp/error.py @@ -1,3 +1,10 @@ +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import httpx + + class MCPError(Exception): pass @@ -7,7 +14,49 @@ class MCPConnectionError(MCPError): class MCPAuthError(MCPConnectionError): - pass + def __init__( + self, + message: str | None = None, + response: "httpx.Response | None" = None, + www_authenticate_header: str | None = None, + ): + """ + MCP Authentication Error. + + Args: + message: Error message + response: HTTP response object (will extract WWW-Authenticate header if provided) + www_authenticate_header: Pre-extracted WWW-Authenticate header value + """ + super().__init__(message or "Authentication failed") + + # Extract OAuth metadata hints from WWW-Authenticate header + if response is not None: + www_authenticate_header = response.headers.get("WWW-Authenticate") + + self.resource_metadata_url: str | None = None + self.scope_hint: str | None = None + + if www_authenticate_header: + self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata") + self.scope_hint = self._extract_field(www_authenticate_header, "scope") + + @staticmethod + def _extract_field(www_auth: str, field_name: str) -> str | None: + """Extract a specific field from the WWW-Authenticate header.""" + # Pattern to match field="value" or field=value + pattern = rf'{field_name}="([^"]*)"' + match = re.search(pattern, www_auth) + if match: + return match.group(1) + + # Try without quotes + pattern = rf"{field_name}=([^\s,]+)" + match = re.search(pattern, www_auth) + if match: + return match.group(1) + + return None class MCPRefreshTokenError(MCPError): diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 3dcd166ea2..c97ae6eac7 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -149,7 +149,7 @@ class BaseSession( messages when entered. """ - _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]] + _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _receive_request_type: type[ReceiveRequestT] @@ -230,7 +230,7 @@ class BaseSession( request_id = self._request_id self._request_id = request_id + 1 - response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue() + response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue() self._response_streams[request_id] = response_queue try: @@ -261,11 +261,17 @@ class BaseSession( message="No response received", ) ) + elif isinstance(response_or_error, HTTPStatusError): + # HTTPStatusError from streamable_client with preserved response object + if response_or_error.response.status_code == 401: + raise MCPAuthError(response=response_or_error.response) + else: + raise MCPConnectionError( + ErrorData(code=response_or_error.response.status_code, message=str(response_or_error)) + ) elif isinstance(response_or_error, JSONRPCError): if response_or_error.error.code == 401: - raise MCPAuthError( - ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) - ) + raise MCPAuthError(message=response_or_error.error.message) else: raise MCPConnectionError( ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) @@ -327,13 +333,17 @@ class BaseSession( if isinstance(message, HTTPStatusError): response_queue = self._response_streams.get(self._request_id - 1) if response_queue is not None: - response_queue.put( - JSONRPCError( - jsonrpc="2.0", - id=self._request_id - 1, - error=ErrorData(code=message.response.status_code, message=message.args[0]), + # For 401 errors, pass the HTTPStatusError directly to preserve response object + if message.response.status_code == 401: + response_queue.put(message) + else: + response_queue.put( + JSONRPCError( + jsonrpc="2.0", + id=self._request_id - 1, + error=ErrorData(code=message.response.status_code, message=message.args[0]), + ) ) - ) else: self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) elif isinstance(message, Exception): diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index fd2062d2e1..335c6a5cbc 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -23,7 +23,7 @@ for reference. not separate types in the schema. """ # Client support both version, not support 2025-06-18 yet. -LATEST_PROTOCOL_VERSION = "2025-03-26" +LATEST_PROTOCOL_VERSION = "2025-06-18" # Server support 2024-11-05 to allow claude to use. SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05" DEFAULT_NEGOTIATED_VERSION = "2025-03-26" @@ -1330,3 +1330,13 @@ class OAuthMetadata(BaseModel): response_types_supported: list[str] grant_types_supported: list[str] | None = None code_challenge_methods_supported: list[str] | None = None + scopes_supported: list[str] | None = None + + +class ProtectedResourceMetadata(BaseModel): + """OAuth 2.0 Protected Resource Metadata (RFC 9470).""" + + resource: str | None = None + authorization_servers: list[str] + scopes_supported: list[str] | None = None + bearer_methods_supported: list[str] | None = None diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index d798e11ff1..7eedf76aed 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -507,7 +507,11 @@ class MCPToolManageService: return auth_result.response def auth_with_actions( - self, provider_entity: MCPProviderEntity, authorization_code: str | None = None + self, + provider_entity: MCPProviderEntity, + authorization_code: str | None = None, + resource_metadata_url: str | None = None, + scope_hint: str | None = None, ) -> dict[str, str]: """ Perform authentication and execute all resulting actions. @@ -517,11 +521,18 @@ class MCPToolManageService: Args: provider_entity: The MCP provider entity authorization_code: Optional authorization code + resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate + scope_hint: Optional scope hint from WWW-Authenticate header Returns: Response dictionary from auth result """ - auth_result = auth(provider_entity, authorization_code) + auth_result = auth( + provider_entity, + authorization_code, + resource_metadata_url=resource_metadata_url, + scope_hint=scope_hint, + ) return self.execute_auth_actions(auth_result) def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 12a9f11205..60f37b6de0 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -23,11 +23,13 @@ from core.mcp.auth.auth_flow import ( ) from core.mcp.entities import AuthActionType, AuthResult from core.mcp.types import ( + LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthTokens, + ProtectedResourceMetadata, ) @@ -154,7 +156,7 @@ class TestOAuthDiscovery: assert auth_url == "https://auth.example.com" mock_get.assert_called_once_with( "https://api.example.com/.well-known/oauth-protected-resource", - headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, + headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}, ) @patch("core.helper.ssrf_proxy.get") @@ -183,59 +185,61 @@ class TestOAuthDiscovery: assert auth_url == "https://auth.example.com" mock_get.assert_called_once_with( "https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment", - headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, + headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}, ) - @patch("core.helper.ssrf_proxy.get") - def test_discover_oauth_metadata_with_resource_discovery(self, mock_get): + def test_discover_oauth_metadata_with_resource_discovery(self): """Test OAuth metadata discovery with resource discovery support.""" - with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: - mock_check.return_value = (True, "https://auth.example.com") + with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: + with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: + # Mock protected resource metadata with auth server URL + mock_prm.return_value = ProtectedResourceMetadata( + resource="https://api.example.com", + authorization_servers=["https://auth.example.com"], + ) - mock_response = Mock() - mock_response.status_code = 200 - mock_response.is_success = True - mock_response.json.return_value = { - "authorization_endpoint": "https://auth.example.com/authorize", - "token_endpoint": "https://auth.example.com/token", - "response_types_supported": ["code"], - } - mock_get.return_value = mock_response + # Mock OAuth authorization server metadata + mock_asm.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ) - metadata = discover_oauth_metadata("https://api.example.com") + oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com") - assert metadata is not None - assert metadata.authorization_endpoint == "https://auth.example.com/authorize" - assert metadata.token_endpoint == "https://auth.example.com/token" - mock_get.assert_called_once_with( - "https://auth.example.com/.well-known/oauth-authorization-server", - headers={"MCP-Protocol-Version": "2025-03-26"}, - ) + assert oauth_metadata is not None + assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize" + assert oauth_metadata.token_endpoint == "https://auth.example.com/token" + assert prm is not None + assert prm.authorization_servers == ["https://auth.example.com"] - @patch("core.helper.ssrf_proxy.get") - def test_discover_oauth_metadata_without_resource_discovery(self, mock_get): + # Verify the discovery functions were called + mock_prm.assert_called_once() + mock_asm.assert_called_once() + + def test_discover_oauth_metadata_without_resource_discovery(self): """Test OAuth metadata discovery without resource discovery.""" - with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: - mock_check.return_value = (False, "") + with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: + with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: + # Mock no protected resource metadata + mock_prm.return_value = None - mock_response = Mock() - mock_response.status_code = 200 - mock_response.is_success = True - mock_response.json.return_value = { - "authorization_endpoint": "https://api.example.com/oauth/authorize", - "token_endpoint": "https://api.example.com/oauth/token", - "response_types_supported": ["code"], - } - mock_get.return_value = mock_response + # Mock OAuth authorization server metadata + mock_asm.return_value = OAuthMetadata( + authorization_endpoint="https://api.example.com/oauth/authorize", + token_endpoint="https://api.example.com/oauth/token", + response_types_supported=["code"], + ) - metadata = discover_oauth_metadata("https://api.example.com") + oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com") - assert metadata is not None - assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize" - mock_get.assert_called_once_with( - "https://api.example.com/.well-known/oauth-authorization-server", - headers={"MCP-Protocol-Version": "2025-03-26"}, - ) + assert oauth_metadata is not None + assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize" + assert prm is None + + # Verify the discovery functions were called + mock_prm.assert_called_once() + mock_asm.assert_called_once() @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_metadata_not_found(self, mock_get): @@ -247,9 +251,9 @@ class TestOAuthDiscovery: mock_response.status_code = 404 mock_get.return_value = mock_response - metadata = discover_oauth_metadata("https://api.example.com") + oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com") - assert metadata is None + assert oauth_metadata is None class TestAuthorizationFlow: @@ -342,6 +346,7 @@ class TestAuthorizationFlow: """Test successful authorization code exchange.""" mock_response = Mock() mock_response.is_success = True + mock_response.headers = {"content-type": "application/json"} mock_response.json.return_value = { "access_token": "new-access-token", "token_type": "Bearer", @@ -412,6 +417,7 @@ class TestAuthorizationFlow: """Test successful token refresh.""" mock_response = Mock() mock_response.is_success = True + mock_response.headers = {"content-type": "application/json"} mock_response.json.return_value = { "access_token": "refreshed-access-token", "token_type": "Bearer", @@ -577,11 +583,15 @@ class TestAuthOrchestration: def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service): """Test auth flow for new client registration.""" # Setup - mock_discover.return_value = OAuthMetadata( - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], + mock_discover.return_value = ( + OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ), + None, + None, ) mock_register.return_value = OAuthClientInformationFull( client_id="new-client-id", @@ -619,11 +629,15 @@ class TestAuthOrchestration: def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service): """Test auth flow for exchanging authorization code.""" # Setup metadata discovery - mock_discover.return_value = OAuthMetadata( - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], + mock_discover.return_value = ( + OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ), + None, + None, ) # Setup existing client @@ -662,11 +676,15 @@ class TestAuthOrchestration: def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service): """Test auth flow fails when exchanging code without state.""" # Setup metadata discovery - mock_discover.return_value = OAuthMetadata( - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], + mock_discover.return_value = ( + OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ), + None, + None, ) mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") @@ -698,11 +716,15 @@ class TestAuthOrchestration: mock_refresh.return_value = new_tokens with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover: - mock_discover.return_value = OAuthMetadata( - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], + mock_discover.return_value = ( + OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ), + None, + None, ) result = auth(mock_provider) @@ -725,11 +747,15 @@ class TestAuthOrchestration: def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service): """Test auth fails when no client info exists but code is provided.""" # Setup metadata discovery - mock_discover.return_value = OAuthMetadata( - authorization_endpoint="https://auth.example.com/authorize", - token_endpoint="https://auth.example.com/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], + mock_discover.return_value = ( + OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ), + None, + None, ) mock_provider.retrieve_client_information.return_value = None diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index aadd366762..490a647025 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -139,7 +139,9 @@ def test_sse_client_error_handling(): with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: # Mock 401 HTTP error - mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401)) + mock_response = Mock(status_code=401) + mock_response.headers = {"WWW-Authenticate": 'Bearer realm="example"'} + mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=mock_response) mock_sse_connect.side_effect = mock_error with pytest.raises(MCPAuthError): @@ -150,7 +152,9 @@ def test_sse_client_error_handling(): with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: # Mock other HTTP error - mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500)) + mock_response = Mock(status_code=500) + mock_response.headers = {} + mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=mock_response) mock_sse_connect.side_effect = mock_error with pytest.raises(MCPConnectionError): diff --git a/api/tests/unit_tests/core/mcp/test_types.py b/api/tests/unit_tests/core/mcp/test_types.py index 6d8130bd13..d4fe353f0a 100644 --- a/api/tests/unit_tests/core/mcp/test_types.py +++ b/api/tests/unit_tests/core/mcp/test_types.py @@ -58,7 +58,7 @@ class TestConstants: def test_protocol_versions(self): """Test protocol version constants.""" - assert LATEST_PROTOCOL_VERSION == "2025-03-26" + assert LATEST_PROTOCOL_VERSION == "2025-06-18" assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05" def test_error_codes(self): diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx index ad528e9fb9..68f97703bf 100644 --- a/web/app/components/tools/mcp/modal.tsx +++ b/web/app/components/tools/mcp/modal.tsx @@ -24,6 +24,8 @@ import { shouldUseMcpIconForAppIcon } from '@/utils/mcp' import TabSlider from '@/app/components/base/tab-slider' import { MCPAuthMethod } from '@/app/components/tools/types' import Switch from '@/app/components/base/switch' +import AlertTriangle from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback/AlertTriangle' +import { API_PREFIX } from '@/config' export type DuplicateAppModalProps = { data?: ToolWithProvider @@ -313,6 +315,17 @@ const MCPModal = ({ /> {t('tools.mcp.modal.useDynamicClientRegistration')} + {!isDynamicRegistration && ( +
+ +
+
{t('tools.mcp.modal.redirectUrlWarning')}
+ + {`${API_PREFIX}/mcp/oauth/callback`} + +
+
+ )}
diff --git a/web/i18n/en-US/tools.ts b/web/i18n/en-US/tools.ts index 308d4b2b05..6086d9aa16 100644 --- a/web/i18n/en-US/tools.ts +++ b/web/i18n/en-US/tools.ts @@ -201,6 +201,7 @@ const translation = { timeoutPlaceholder: '30', authentication: 'Authentication', useDynamicClientRegistration: 'Use Dynamic Client Registration', + redirectUrlWarning: 'Please configure your OAuth redirect URL to:', clientID: 'Client ID', clientSecret: 'Client Secret', clientSecretPlaceholder: 'Client Secret', diff --git a/web/i18n/zh-Hans/tools.ts b/web/i18n/zh-Hans/tools.ts index cab4b22164..ad046ff198 100644 --- a/web/i18n/zh-Hans/tools.ts +++ b/web/i18n/zh-Hans/tools.ts @@ -201,6 +201,7 @@ const translation = { timeoutPlaceholder: '30', authentication: '认证', useDynamicClientRegistration: '使用动态客户端注册', + redirectUrlWarning: '请将您的 OAuth 重定向 URL 配置为:', clientID: '客户端 ID', clientSecret: '客户端密钥', clientSecretPlaceholder: '客户端密钥',