From 5c6a2af4481c8c9c7fd0c1985401a7b9b5e80c4b Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 14 Oct 2025 20:36:13 +0800 Subject: [PATCH] chore: fix review issues --- .../console/workspace/tool_providers.py | 81 ++++++------- api/core/entities/mcp_provider.py | 107 +++++++++--------- api/core/mcp/auth/auth_flow.py | 82 ++++++++------ api/core/mcp/auth_client.py | 4 + api/core/tools/entities/api_entities.py | 15 ++- api/core/tools/mcp_tool/provider.py | 7 +- api/core/tools/mcp_tool/tool.py | 103 ++++++++++------- .../tools/mcp_tools_manage_service.py | 98 +++++++--------- api/services/tools/tools_transform_service.py | 8 ++ .../core/mcp/auth/test_auth_flow.py | 40 ++++--- .../tools/test_mcp_tools_transform.py | 8 +- 11 files changed, 296 insertions(+), 257 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5d91b37702..93176ccf16 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -16,7 +16,7 @@ from controllers.console.wraps import ( enterprise_license_required, setup_required, ) -from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient @@ -44,7 +44,9 @@ def is_valid_url(url: str) -> bool: try: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] - except Exception: + except (ValueError, TypeError): + # ValueError: Invalid URL format + # TypeError: url is not a string return False @@ -886,7 +888,7 @@ class ToolProviderMCPApi(Resource): authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None # Create provider - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) result = service.create_provider( tenant_id=tenant_id, @@ -897,14 +899,10 @@ class ToolProviderMCPApi(Resource): icon_type=args["icon_type"], icon_background=args["icon_background"], server_identifier=args["server_identifier"], - timeout=configuration.timeout, - sse_read_timeout=configuration.sse_read_timeout, headers=args["headers"], - client_id=authentication.client_id if authentication else None, - client_secret=authentication.client_secret if authentication else None, - grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE, + configuration=configuration, + authentication=authentication, ) - session.commit() return jsonable_encoder(result) @setup_required @@ -932,7 +930,7 @@ class ToolProviderMCPApi(Resource): authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, @@ -943,14 +941,10 @@ class ToolProviderMCPApi(Resource): icon_type=args["icon_type"], icon_background=args["icon_background"], server_identifier=args["server_identifier"], - timeout=configuration.timeout, - sse_read_timeout=configuration.sse_read_timeout, headers=args["headers"], - client_id=authentication.client_id if authentication else None, - client_secret=authentication.client_secret if authentication else None, - grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE, + configuration=configuration, + authentication=authentication, ) - session.commit() return {"result": "success"} @setup_required @@ -962,10 +956,9 @@ class ToolProviderMCPApi(Resource): args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id.current_tenant_id, provider_id=args["provider_id"]) - session.commit() + service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) return {"result": "success"} @@ -983,23 +976,18 @@ class ToolMCPAuthApi(Resource): _, tenant_id = current_account_with_tenant() with Session(db.engine) as session: - service = MCPToolManageService(session=session) - db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) - if not db_provider: - raise ValueError("provider not found") + with session.begin(): + service = MCPToolManageService(session=session) + db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) + if not db_provider: + raise ValueError("provider not found") - # Convert to entity - provider_entity = db_provider.to_entity() - server_url = provider_entity.decrypt_server_url() + # Convert to entity + provider_entity = db_provider.to_entity() + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_authentication() - # Option 1: if headers is provided, use it and don't need to get token - headers = provider_entity.decrypt_headers() - - # Option 2: Add OAuth token if authed and no headers provided - if not provider_entity.headers and provider_entity.authed: - token = provider_entity.retrieve_tokens() - if token: - headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" + # Try to connect without active transaction try: # Use MCPClientWithAuthRetry to handle authentication automatically with MCPClient( @@ -1008,18 +996,20 @@ class ToolMCPAuthApi(Resource): timeout=provider_entity.timeout, sse_read_timeout=provider_entity.sse_read_timeout, ): - service.update_provider_credentials( - provider=db_provider, - credentials=provider_entity.credentials, - authed=True, - ) - session.commit() + # Create new transaction for update + with session.begin(): + service.update_provider_credentials( + provider=db_provider, + credentials=provider_entity.credentials, + authed=True, + ) return {"result": "success"} except MCPAuthError as e: + service = MCPToolManageService(session=session) return auth(provider_entity, service, args.get("authorization_code")) except MCPError as e: - service.clear_provider_credentials(provider=db_provider) - session.commit() + with session.begin(): + service.clear_provider_credentials(provider=db_provider) raise ValueError(f"Failed to connect to MCP server: {e}") from e @@ -1044,7 +1034,7 @@ class ToolMCPListAllApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: service = MCPToolManageService(session=session) tools = service.list_providers(tenant_id=tenant_id) @@ -1058,7 +1048,7 @@ class ToolMCPUpdateApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) tools = service.list_provider_tools( tenant_id=tenant_id, @@ -1078,9 +1068,8 @@ class ToolMCPCallbackApi(Resource): authorization_code = args["code"] # Create service instance for handle_callback - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): mcp_service = MCPToolManageService(session=session) handle_callback(state_key, authorization_code, mcp_service) - session.commit() return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index ef95ed9b4c..4295aa91f1 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -4,7 +4,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from pydantic import BaseModel, Field +from pydantic import BaseModel from configs import dify_config from core.entities.provider_entities import BasicProviderConfig @@ -20,7 +20,6 @@ if TYPE_CHECKING: from models.tools import MCPToolProvider # Constants -DEFAULT_GRANT_TYPE = "authorization_code" CLIENT_NAME = "Dify" CLIENT_URI = "https://github.com/langgenius/dify" DEFAULT_TOKEN_TYPE = "Bearer" @@ -34,12 +33,12 @@ class MCPSupportGrantType(StrEnum): AUTHORIZATION_CODE = "authorization_code" CLIENT_CREDENTIALS = "client_credentials" + REFRESH_TOKEN = "refresh_token" class MCPAuthentication(BaseModel): client_id: str client_secret: str | None = None - grant_type: MCPSupportGrantType = Field(default=MCPSupportGrantType.AUTHORIZATION_CODE) class MCPConfiguration(BaseModel): @@ -110,7 +109,7 @@ class MCPProviderEntity(BaseModel): credentials = self.decrypt_credentials() # Try to get grant_type from different locations - grant_type = credentials.get("grant_type", DEFAULT_GRANT_TYPE) + grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE) # For nested structure, check if client_information has grant_types if "client_information" in credentials and isinstance(credentials["client_information"], dict): @@ -118,12 +117,12 @@ class MCPProviderEntity(BaseModel): # If grant_types is specified in client_information, use it to determine grant_type if "grant_types" in client_info and isinstance(client_info["grant_types"], list): if "client_credentials" in client_info["grant_types"]: - grant_type = "client_credentials" + grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS elif "authorization_code" in client_info["grant_types"]: - grant_type = "authorization_code" + grant_type = MCPSupportGrantType.AUTHORIZATION_CODE # Configure based on grant type - is_client_credentials = grant_type == "client_credentials" + is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS grant_types = ["refresh_token"] grant_types.append("client_credentials" if is_client_credentials else "authorization_code") @@ -212,10 +211,7 @@ class MCPProviderEntity(BaseModel): json_fields = ["redirect_uris", "grant_types", "response_types"] for field in json_fields: if field in credentials: - try: - client_info[field] = json.loads(credentials[field]) - except: - client_info[field] = [] + client_info[field] = json.loads(credentials[field]) if "scope" in credentials: client_info["scope"] = credentials["scope"] @@ -237,10 +233,10 @@ class MCPProviderEntity(BaseModel): def masked_server_url(self) -> str: """Masked server URL for display""" parsed = urlparse(self.decrypt_server_url()) - base_url = f"{parsed.scheme}://{parsed.netloc}" if parsed.path and parsed.path != "/": - return f"{base_url}/******" - return base_url + masked = parsed._replace(path="/******") + return masked.geturl() + return parsed.geturl() def _mask_value(self, value: str) -> str: """Mask a sensitive value for display""" @@ -289,46 +285,41 @@ class MCPProviderEntity(BaseModel): def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]: """Generic method to decrypt dictionary fields""" - try: - if not data: - return {} - - # Only decrypt fields that are actually encrypted - # For nested structures, client_information is not encrypted as a whole - encrypted_fields = [] - for key, value in data.items(): - # Skip nested objects - they are not encrypted - if isinstance(value, dict): - continue - # Only process string values that might be encrypted - if isinstance(value, str) and value: - encrypted_fields.append(key) - - if not encrypted_fields: - return data - - # Create dynamic config only for encrypted fields - config = [ - BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields - ] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - # Decrypt only the encrypted fields - decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields}) - - # Merge decrypted data with original data (preserving non-encrypted fields) - result = data.copy() - result.update(decrypted_data) - - return result - except Exception: + if not data: return {} + # Only decrypt fields that are actually encrypted + # For nested structures, client_information is not encrypted as a whole + encrypted_fields = [] + for key, value in data.items(): + # Skip nested objects - they are not encrypted + if isinstance(value, dict): + continue + # Only process string values that might be encrypted + if isinstance(value, str) and value: + encrypted_fields.append(key) + + if not encrypted_fields: + return data + + # Create dynamic config only for encrypted fields + config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields] + + encrypter_instance, _ = create_provider_encrypter( + tenant_id=self.tenant_id, + config=config, + cache=NoOpProviderCredentialCache(), + ) + + # Decrypt only the encrypted fields + decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields}) + + # Merge decrypted data with original data (preserving non-encrypted fields) + result = data.copy() + result.update(decrypted_data) + + return result + def decrypt_headers(self) -> dict[str, Any]: """Decrypt headers""" return self._decrypt_dict(self.headers) @@ -336,3 +327,15 @@ class MCPProviderEntity(BaseModel): def decrypt_credentials(self) -> dict[str, Any]: """Decrypt credentials""" return self._decrypt_dict(self.credentials) + + def decrypt_authentication(self) -> dict[str, Any]: + """Decrypt authentication""" + # Option 1: if headers is provided, use it and don't need to get token + headers = self.decrypt_headers() + + # Option 2: Add OAuth token if authed and no headers provided + if not self.headers and self.authed: + token = self.retrieve_tokens() + if token: + headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" + return headers diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 138be598c8..4ebf97c7f2 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -7,10 +7,11 @@ import urllib.parse from typing import TYPE_CHECKING from urllib.parse import urljoin, urlparse -import httpx +from httpx import ConnectError, HTTPStatusError, RequestError from pydantic import BaseModel, ValidationError -from core.entities.mcp_provider import MCPProviderEntity +from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType +from core.helper import ssrf_proxy from core.mcp.types import ( LATEST_PROTOCOL_VERSION, OAuthClientInformation, @@ -106,15 +107,15 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: """Check if the server supports OAuth 2.0 Resource Discovery.""" - b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True) - url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" + b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True) + url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource" if b_query: url_for_resource_discovery += f"?{b_query}" if b_fragment: url_for_resource_discovery += f"#{b_fragment}" try: headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} - response = httpx.get(url_for_resource_discovery, headers=headers) + response = ssrf_proxy.get(url_for_resource_discovery, headers=headers) if 200 <= response.status_code < 300: body = response.json() # Support both singular and plural forms @@ -125,7 +126,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: else: return False, "" return False, "" - except httpx.RequestError: + except RequestError: # Not support resource discovery, fall back to well-known OAuth metadata return False, "" @@ -138,8 +139,8 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None # The oauth_discovery_url is the authorization server base URL # Try OpenID Connect discovery first (more common), then OAuth 2.0 urls_to_try = [ - urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"), urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"), + urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"), ] else: urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")] @@ -148,15 +149,15 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None for url in urls_to_try: try: - response = httpx.get(url, headers=headers) + response = ssrf_proxy.get(url, headers=headers) if response.status_code == 404: - continue # Try next URL + continue if not response.is_success: - raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + response.raise_for_status() return OAuthMetadata.model_validate(response.json()) - except httpx.RequestError as e: - if isinstance(e, httpx.ConnectError): - response = httpx.get(url) + except (RequestError, HTTPStatusError) as e: + if isinstance(e, ConnectError): + response = ssrf_proxy.get(url) if response.status_code == 404: continue # Try next URL if not response.is_success: @@ -232,7 +233,7 @@ def exchange_authorization( redirect_uri: str, ) -> OAuthTokens: """Exchanges an authorization code for an access token.""" - grant_type = "authorization_code" + grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value if metadata: token_url = metadata.token_endpoint @@ -252,7 +253,7 @@ def exchange_authorization( if client_information.client_secret: params["client_secret"] = client_information.client_secret - response = httpx.post(token_url, data=params) + response = ssrf_proxy.post(token_url, data=params) if not response.is_success: raise ValueError(f"Token exchange failed: HTTP {response.status_code}") return OAuthTokens.model_validate(response.json()) @@ -265,7 +266,7 @@ def refresh_authorization( refresh_token: str, ) -> OAuthTokens: """Exchange a refresh token for an updated access token.""" - grant_type = "refresh_token" + grant_type = MCPSupportGrantType.REFRESH_TOKEN.value if metadata: token_url = metadata.token_endpoint @@ -283,7 +284,7 @@ def refresh_authorization( if client_information.client_secret: params["client_secret"] = client_information.client_secret - response = httpx.post(token_url, data=params) + response = ssrf_proxy.post(token_url, data=params) if not response.is_success: raise ValueError(f"Token refresh failed: HTTP {response.status_code}") return OAuthTokens.model_validate(response.json()) @@ -296,7 +297,7 @@ def client_credentials_flow( scope: str | None = None, ) -> OAuthTokens: """Execute Client Credentials Flow to get access token.""" - grant_type = "client_credentials" + grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value if metadata: token_url = metadata.token_endpoint @@ -323,7 +324,7 @@ def client_credentials_flow( if client_information.client_secret: data["client_secret"] = client_information.client_secret - response = httpx.post(token_url, headers=headers, data=data) + response = ssrf_proxy.post(token_url, headers=headers, data=data) if not response.is_success: raise ValueError( f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}" @@ -345,7 +346,7 @@ def register_client( else: registration_url = urljoin(server_url, "/register") - response = httpx.post( + response = ssrf_proxy.post( registration_url, json=client_metadata.model_dump(), headers={"Content-Type": "application/json"}, @@ -360,7 +361,6 @@ def auth( mcp_service: "MCPToolManageService", authorization_code: str | None = None, state_param: str | None = None, - grant_type: str = "authorization_code", ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" server_url = provider.decrypt_server_url() @@ -371,25 +371,37 @@ def auth( client_information = provider.retrieve_client_information() redirect_url = provider.redirect_url - # Check if we should use client credentials flow - credentials = provider.decrypt_credentials() - stored_grant_type = credentials.get("grant_type", "authorization_code") + # Determine grant type based on server metadata + if not server_metadata: + raise ValueError("Failed to discover OAuth metadata from server") - # Use stored grant type if available, otherwise use parameter - effective_grant_type = stored_grant_type or grant_type + supported_grant_types = server_metadata.grant_types_supported or [] + + # Convert to lowercase for comparison + supported_grant_types_lower = [gt.lower() for gt in supported_grant_types] + + # Determine which grant type to use + effective_grant_type = None + if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower: + effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value + else: + effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value + + # Get stored credentials + credentials = provider.decrypt_credentials() if not client_information: if authorization_code is not None: raise ValueError("Existing OAuth client information is required when exchanging an authorization code") # For client credentials flow, we don't need to register client dynamically - if effective_grant_type == "client_credentials": + if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value: # Client should provide client_id and client_secret directly raise ValueError("Client credentials flow requires client_id and client_secret to be provided") try: full_information = register_client(server_url, server_metadata, client_metadata) - except httpx.RequestError as e: + except RequestError as e: raise ValueError(f"Could not register OAuth client: {e}") # Save client information using service layer @@ -400,7 +412,7 @@ def auth( client_information = full_information # Handle client credentials flow - if effective_grant_type == "client_credentials": + if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value: # Direct token request without user interaction try: scope = credentials.get("scope") @@ -413,11 +425,14 @@ def auth( # Save tokens and grant type token_data = tokens.model_dump() - token_data["grant_type"] = "client_credentials" + token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens") return {"result": "success"} - except Exception as e: + except (RequestError, ValueError, KeyError) as e: + # RequestError: HTTP request failed + # ValueError: Invalid response data + # KeyError: Missing required fields in response raise ValueError(f"Client credentials flow failed: {e}") # Exchange authorization code for tokens (Authorization Code flow) @@ -465,7 +480,10 @@ def auth( mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens") return {"result": "success"} - except Exception as e: + except (RequestError, ValueError, KeyError) as e: + # RequestError: HTTP request failed + # ValueError: Invalid response data + # KeyError: Missing required fields in response raise ValueError(f"Could not refresh OAuth tokens: {e}") # Start new authorization flow (only for authorization code flow) diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index 839090fde9..95f552f5db 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -99,7 +99,11 @@ class MCPClientWithAuthRetry(MCPClient): # Clear authorization code after first use self.authorization_code = None + except MCPAuthError: + # Re-raise MCPAuthError as is + raise except Exception as e: + # Catch all exceptions during auth retry logger.exception("Authentication retry failed") raise MCPAuthError(f"Authentication retry failed: {e}") from e diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index c3cb19a312..8f7d1101cb 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject @@ -47,9 +48,9 @@ class ToolProviderApiEntity(BaseModel): masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") - authentication: dict[str, str] | None = Field(default=None, description="The OAuth config of the MCP tool") + authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool") is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered") - configuration: dict[str, str] | None = Field( + configuration: MCPConfiguration | None = Field( default=None, description="The timeout and sse_read_timeout of the MCP tool" ) @@ -74,8 +75,14 @@ class ToolProviderApiEntity(BaseModel): if self.type == ToolProviderType.MCP: optional_fields.update(self.optional_field("updated_at", self.updated_at)) optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) - optional_fields.update(self.optional_field("configuration", self.configuration)) - optional_fields.update(self.optional_field("authentication", self.authentication)) + optional_fields.update( + self.optional_field( + "configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration() + ) + ) + optional_fields.update( + self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None) + ) optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) optional_fields.update(self.optional_field("original_headers", self.original_headers)) diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 3161a45635..3404f5c3b4 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -61,10 +61,7 @@ class MCPToolProviderController(ToolProviderController): """ create a MCPToolProviderController from a MCPProviderEntity """ - try: - remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools] - except Exception: - remote_mcp_tools = [] + remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools] tools = [ ToolEntity( @@ -87,7 +84,7 @@ class MCPToolProviderController(ToolProviderController): ) for remote_mcp_tool in remote_mcp_tools ] - if not db_provider.icon: + if not entity.icon: raise ValueError("Database provider icon is required") return cls( entity=ToolProviderEntityWithPlugin( diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 12749c9b89..290077ecd8 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -60,11 +60,18 @@ class MCPTool(Tool): def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]: """Process text content and yield appropriate messages.""" - try: - content_json = json.loads(content.text) - yield from self._process_json_content(content_json) - except json.JSONDecodeError: - yield self.create_text_message(content.text) + # Check if content looks like JSON before attempting to parse + text = content.text.strip() + if text and text[0] in ("{", "[") and text[-1] in ("}", "]"): + try: + content_json = json.loads(text) + yield from self._process_json_content(content_json) + return + except json.JSONDecodeError: + pass + + # If not JSON or parsing failed, treat as plain text + yield self.create_text_message(content.text) def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: """Process JSON content based on its type.""" @@ -119,10 +126,6 @@ class MCPTool(Tool): tool_parameters = self._handle_none_parameter(tool_parameters) # Get provider entity to access tokens - from typing import TYPE_CHECKING - - if TYPE_CHECKING: - pass # Get MCP service from invoke parameters or create new one provider_entity = None @@ -131,18 +134,7 @@ class MCPTool(Tool): # Check if mcp_service is passed in tool_parameters if "_mcp_service" in tool_parameters: mcp_service = tool_parameters.pop("_mcp_service") - else: - # Fallback to creating service with database session - from sqlalchemy.orm import Session - - from extensions.ext_database import db - from services.tools.mcp_tools_manage_service import MCPToolManageService - - with Session(db.engine) as session: - mcp_service = MCPToolManageService(session=session) - - if mcp_service: - try: + if mcp_service: provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) headers = provider_entity.decrypt_headers() # Try to get existing token and add to headers @@ -150,23 +142,54 @@ class MCPTool(Tool): tokens = provider_entity.retrieve_tokens() if tokens and tokens.access_token: headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - except Exception: - # If provider retrieval or token fails, continue without auth - pass - # Use MCPClientWithAuthRetry to handle authentication automatically - try: - with MCPClientWithAuthRetry( - server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - provider_entity=provider_entity, - auth_callback=auth if mcp_service else None, - mcp_service=mcp_service, - ) as mcp_client: - return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - except MCPConnectionError as e: - raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e - except Exception as e: - raise ToolInvokeError(f"Failed to invoke tool: {e}") from e + # Use MCPClientWithAuthRetry to handle authentication automatically + try: + with MCPClientWithAuthRetry( + server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, + headers=headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + provider_entity=provider_entity, + auth_callback=auth if mcp_service else None, + mcp_service=mcp_service, + ) as mcp_client: + return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPConnectionError as e: + raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e + except (ValueError, TypeError, KeyError) as e: + # Catch specific exceptions that might occur during tool invocation + raise ToolInvokeError(f"Failed to invoke tool: {e}") from e + else: + # Fallback to creating service with database session + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from services.tools.mcp_tools_manage_service import MCPToolManageService + + with Session(db.engine, expire_on_commit=False) as session: + mcp_service = MCPToolManageService(session=session) + provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) + headers = provider_entity.decrypt_headers() + # Try to get existing token and add to headers + if not headers: + tokens = provider_entity.retrieve_tokens() + if tokens and tokens.access_token: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + + # Use MCPClientWithAuthRetry to handle authentication automatically + try: + with MCPClientWithAuthRetry( + server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, + headers=headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + provider_entity=provider_entity, + auth_callback=auth if mcp_service else None, + mcp_service=mcp_service, + ) as mcp_client: + return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPConnectionError as e: + raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e + except Exception as e: + raise ToolInvokeError(f"Failed to invoke tool: {e}") from e diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 0529b84fdb..57a2cd49f9 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -9,8 +9,7 @@ from sqlalchemy import or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from configs import dify_config -from core.entities.mcp_provider import MCPProviderEntity +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth_client import MCPClientWithAuthRetry @@ -24,7 +23,6 @@ logger = logging.getLogger(__name__) # Constants UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" -DEFAULT_GRANT_TYPE = "authorization_code" CLIENT_NAME = "Dify" EMPTY_TOOLS_JSON = "[]" EMPTY_CREDENTIALS_JSON = "{}" @@ -88,12 +86,9 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, - timeout: float, - sse_read_timeout: float, + configuration: MCPConfiguration, + authentication: MCPAuthentication | None = None, headers: dict[str, str] | None = None, - client_id: str | None = None, - client_secret: str | None = None, - grant_type: str = DEFAULT_GRANT_TYPE, ) -> ToolProviderApiEntity: """Create a new MCP provider.""" server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() @@ -104,9 +99,11 @@ class MCPToolManageService: # Encrypt sensitive data encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None - if client_id and client_secret: + if authentication is not None: # Build the full credentials structure with encrypted client_id and client_secret - encrypted_credentials = self._build_and_encrypt_credentials(client_id, client_secret, grant_type, tenant_id) + encrypted_credentials = self._build_and_encrypt_credentials( + authentication.client_id, authentication.client_secret, tenant_id + ) else: encrypted_credentials = None # Create provider @@ -120,16 +117,16 @@ class MCPToolManageService: tools=EMPTY_TOOLS_JSON, icon=self._prepare_icon(icon, icon_type, icon_background), server_identifier=server_identifier, - timeout=timeout, - sse_read_timeout=sse_read_timeout, + timeout=configuration.timeout, + sse_read_timeout=configuration.sse_read_timeout, encrypted_headers=encrypted_headers, encrypted_credentials=encrypted_credentials, ) self._session.add(mcp_tool) - self._session.commit() - - return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) + self._session.flush() + mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) + return mcp_providers def update_provider( self, @@ -142,12 +139,9 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, - timeout: float | None = None, - sse_read_timeout: float | None = None, headers: dict[str, str] | None = None, - client_id: str | None = None, - client_secret: str | None = None, - grant_type: str | None = None, + configuration: MCPConfiguration, + authentication: MCPAuthentication | None = None, ) -> None: """Update an MCP provider.""" mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) @@ -185,10 +179,10 @@ class MCPToolManageService: mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] # Update optional fields - if timeout is not None: - mcp_provider.timeout = timeout - if sse_read_timeout is not None: - mcp_provider.sse_read_timeout = sse_read_timeout + if configuration.timeout is not None: + mcp_provider.timeout = configuration.timeout + if configuration.sse_read_timeout is not None: + mcp_provider.sse_read_timeout = configuration.sse_read_timeout if headers is not None: if headers: # Build headers preserving unchanged masked values @@ -200,20 +194,18 @@ class MCPToolManageService: mcp_provider.encrypted_headers = None # Update credentials if provided - if client_id is not None and client_secret is not None: + if authentication is not None: # Merge with existing credentials to handle masked values ( final_client_id, final_client_secret, - final_grant_type, - ) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, mcp_provider) - - # Use default grant_type if none found - final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE + ) = self._merge_credentials_with_masked( + authentication.client_id, authentication.client_secret, mcp_provider + ) # Build and encrypt new credentials encrypted_credentials = self._build_and_encrypt_credentials( - final_client_id, final_client_secret, final_grant_type, tenant_id + final_client_id, final_client_secret, tenant_id ) mcp_provider.encrypted_credentials = encrypted_credentials @@ -221,7 +213,11 @@ class MCPToolManageService: except IntegrityError as e: self._session.rollback() self._handle_integrity_error(e, name, server_url, server_identifier) - except Exception: + except (ValueError, AttributeError, TypeError) as e: + # Catch specific exceptions that might occur during update + # ValueError: invalid data provided + # AttributeError: missing required attributes + # TypeError: type conversion errors self._session.rollback() raise @@ -271,7 +267,7 @@ class MCPToolManageService: db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) db_provider.authed = True db_provider.updated_at = datetime.now() - self._session.commit() + self._session.flush() # Build API response return self._build_tool_provider_response(db_provider, provider_entity, tools) @@ -309,7 +305,7 @@ class MCPToolManageService: if not authed: provider.tools = EMPTY_TOOLS_JSON - self._session.commit() + self._session.flush() def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None: """ @@ -495,20 +491,21 @@ class MCPToolManageService: def _merge_credentials_with_masked( self, client_id: str, - client_secret: str, - grant_type: str | None, + client_secret: str | None, mcp_provider: MCPToolProvider, - ) -> tuple[str, str, str | None]: + ) -> tuple[ + str, + str | None, + ]: """Merge incoming credentials with existing ones, preserving unchanged masked values. Args: client_id: Client ID from frontend (may be masked) client_secret: Client secret from frontend (may be masked) - grant_type: Grant type from frontend mcp_provider: The MCP provider instance Returns: - Tuple of (final_client_id, final_client_secret, grant_type) + Tuple of (final_client_id, final_client_secret) """ mcp_provider_entity = mcp_provider.to_entity() existing_decrypted = mcp_provider_entity.decrypt_credentials() @@ -526,35 +523,18 @@ class MCPToolManageService: # Use existing decrypted value final_client_secret = existing_decrypted.get("client_secret", client_secret) - final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type") + return final_client_id, final_client_secret - return final_client_id, final_client_secret, final_grant_type - - def _build_and_encrypt_credentials( - self, client_id: str, client_secret: str, grant_type: str, tenant_id: str - ) -> str: + def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str: """Build credentials and encrypt sensitive fields.""" # Create a flat structure with all credential data credentials_data = { "client_id": client_id, "client_secret": client_secret, - "grant_type": grant_type, "client_name": CLIENT_NAME, "is_dynamic_registration": False, } - # Add grant types and response types based on grant_type - if grant_type == "client_credentials": - credentials_data["grant_types"] = json.dumps(["client_credentials"]) - credentials_data["response_types"] = json.dumps([]) - credentials_data["redirect_uris"] = json.dumps([]) - else: - credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"]) - credentials_data["response_types"] = json.dumps(["code"]) - credentials_data["redirect_uris"] = json.dumps( - [f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"] - ) - # Only client_id and client_secret need encryption - secret_fields = ["client_id", "client_secret"] + secret_fields = ["client_id", "client_secret"] if client_secret else ["client_id"] return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index e68ecfe74d..22f63a7aa4 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -6,6 +6,7 @@ from typing import Any, Union from yarl import URL from configs import dify_config +from core.entities.mcp_provider import MCPConfiguration from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity @@ -246,6 +247,13 @@ class ToolTransformService: ) response["server_identifier"] = db_provider.server_identifier + # Convert configuration dict to MCPConfiguration object + if "configuration" in response and isinstance(response["configuration"], dict): + response["configuration"] = MCPConfiguration( + timeout=float(response["configuration"]["timeout"]), + sse_read_timeout=float(response["configuration"]["sse_read_timeout"]), + ) + return ToolProviderApiEntity(**response) @staticmethod diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 986dbefdf6..26b5d1f7ce 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -139,7 +139,7 @@ class TestRedisStateManagement: class TestOAuthDiscovery: """Test OAuth discovery functions.""" - @patch("httpx.get") + @patch("core.helper.ssrf_proxy.get") def test_check_support_resource_discovery_success(self, mock_get): """Test successful resource discovery check.""" mock_response = Mock() @@ -152,11 +152,11 @@ class TestOAuthDiscovery: assert supported is True assert auth_url == "https://auth.example.com" mock_get.assert_called_once_with( - "https://api.example.com/.well-known/oauth-protected-resource/endpoint", + "https://api.example.com/.well-known/oauth-protected-resource", headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, ) - @patch("httpx.get") + @patch("core.helper.ssrf_proxy.get") def test_check_support_resource_discovery_not_supported(self, mock_get): """Test resource discovery not supported.""" mock_response = Mock() @@ -168,7 +168,7 @@ class TestOAuthDiscovery: assert supported is False assert auth_url == "" - @patch("httpx.get") + @patch("core.helper.ssrf_proxy.get") def test_check_support_resource_discovery_with_query_fragment(self, mock_get): """Test resource discovery with query and fragment.""" mock_response = Mock() @@ -181,11 +181,11 @@ class TestOAuthDiscovery: assert supported is True assert auth_url == "https://auth.example.com" mock_get.assert_called_once_with( - "https://api.example.com/.well-known/oauth-protected-resource/path?query=1#fragment", + "https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment", headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, ) - @patch("httpx.get") + @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_metadata_with_resource_discovery(self, mock_get): """Test OAuth metadata discovery with resource discovery support.""" with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: @@ -207,11 +207,11 @@ class TestOAuthDiscovery: assert metadata.authorization_endpoint == "https://auth.example.com/authorize" assert metadata.token_endpoint == "https://auth.example.com/token" mock_get.assert_called_once_with( - "https://auth.example.com/.well-known/openid-configuration", + "https://auth.example.com/.well-known/oauth-authorization-server", headers={"MCP-Protocol-Version": "2025-03-26"}, ) - @patch("httpx.get") + @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_metadata_without_resource_discovery(self, mock_get): """Test OAuth metadata discovery without resource discovery.""" with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: @@ -236,7 +236,7 @@ class TestOAuthDiscovery: headers={"MCP-Protocol-Version": "2025-03-26"}, ) - @patch("httpx.get") + @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_metadata_not_found(self, mock_get): """Test OAuth metadata discovery when not found.""" with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: @@ -336,7 +336,7 @@ class TestAuthorizationFlow: assert "does not support response type code" in str(exc_info.value) - @patch("httpx.post") + @patch("core.helper.ssrf_proxy.post") def test_exchange_authorization_success(self, mock_post): """Test successful authorization code exchange.""" mock_response = Mock() @@ -384,7 +384,7 @@ class TestAuthorizationFlow: }, ) - @patch("httpx.post") + @patch("core.helper.ssrf_proxy.post") def test_exchange_authorization_failure(self, mock_post): """Test failed authorization code exchange.""" mock_response = Mock() @@ -406,7 +406,7 @@ class TestAuthorizationFlow: assert "Token exchange failed: HTTP 400" in str(exc_info.value) - @patch("httpx.post") + @patch("core.helper.ssrf_proxy.post") def test_refresh_authorization_success(self, mock_post): """Test successful token refresh.""" mock_response = Mock() @@ -442,7 +442,7 @@ class TestAuthorizationFlow: }, ) - @patch("httpx.post") + @patch("core.helper.ssrf_proxy.post") def test_register_client_success(self, mock_post): """Test successful client registration.""" mock_response = Mock() @@ -576,7 +576,12 @@ class TestAuthOrchestration: def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service): """Test auth flow for new client registration.""" # Setup - mock_discover.return_value = None + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) mock_register.return_value = OAuthClientInformationFull( client_id="new-client-id", client_name="Dify", @@ -679,7 +684,12 @@ class TestAuthOrchestration: mock_refresh.return_value = new_tokens with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover: - mock_discover.return_value = None + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) result = auth(mock_provider, mock_service) diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py index 8505b3cbec..7511fd6f0c 100644 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py @@ -228,8 +228,7 @@ class TestMCPToolTransform: "masked_headers": {"Authorization": "Bearer *****"}, "updated_at": 1234567890, "labels": [], - "timeout": 30, - "sse_read_timeout": 300, + "configuration": {"timeout": "30", "sse_read_timeout": "300"}, "original_headers": {"Authorization": "Bearer secret-token"}, "author": "Test User", "description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"), @@ -246,8 +245,9 @@ class TestMCPToolTransform: assert isinstance(result, ToolProviderApiEntity) assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False assert result.server_identifier == "server-identifier-456" - assert result.timeout == 30 - assert result.sse_read_timeout == 300 + assert result.configuration is not None + assert result.configuration.timeout == 30 + assert result.configuration.sse_read_timeout == 300 assert result.original_headers == {"Authorization": "Bearer secret-token"} assert len(result.tools) == 1 assert result.tools[0].description.en_US == "Tool description"