From 0ded6303c1e29b3bc01b46f7b11077cd4bf50d44 Mon Sep 17 00:00:00 2001 From: Novice Date: Mon, 27 Oct 2025 17:07:51 +0800 Subject: [PATCH] feat: implement MCP specification 2025-06-18 (#25766) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/workspace/tool_providers.py | 211 +++-- api/core/entities/mcp_provider.py | 328 +++++++ api/core/mcp/auth/auth_flow.py | 320 +++++-- api/core/mcp/auth/auth_provider.py | 77 -- api/core/mcp/auth_client.py | 191 ++++ api/core/mcp/auth_client_comparison.md | 0 api/core/mcp/client/sse_client.py | 51 +- api/core/mcp/client/streamable_client.py | 79 +- api/core/mcp/entities.py | 47 +- api/core/mcp/error.py | 4 + api/core/mcp/mcp_client.py | 103 +- api/core/mcp/session/base_session.py | 7 +- api/core/mcp/session/client_session.py | 2 +- api/core/mcp/types.py | 264 ++++-- api/core/tools/__base/tool.py | 13 + api/core/tools/entities/api_entities.py | 20 +- api/core/tools/mcp_tool/provider.py | 46 +- api/core/tools/mcp_tool/tool.py | 92 +- api/core/tools/tool_manager.py | 75 +- api/models/tools.py | 123 +-- .../tools/mcp_tools_manage_service.py | 888 +++++++++++++----- api/services/tools/tools_transform_service.py | 77 +- .../tools/test_mcp_tools_manage_service.py | 515 ++++++---- api/tests/unit_tests/core/mcp/__init__.py | 0 .../unit_tests/core/mcp/auth/__init__.py | 0 .../core/mcp/auth/test_auth_flow.py | 740 +++++++++++++++ .../core/mcp/test_auth_client_inheritance.py | 0 .../unit_tests/core/mcp/test_entities.py | 239 +++++ api/tests/unit_tests/core/mcp/test_error.py | 205 ++++ .../unit_tests/core/mcp/test_mcp_client.py | 382 ++++++++ api/tests/unit_tests/core/mcp/test_types.py | 492 ++++++++++ api/tests/unit_tests/core/mcp/test_utils.py | 355 +++++++ .../tools/test_mcp_tools_transform.py | 45 +- 33 files changed, 4863 insertions(+), 1128 deletions(-) create mode 100644 api/core/entities/mcp_provider.py delete mode 100644 api/core/mcp/auth/auth_provider.py create mode 100644 api/core/mcp/auth_client.py create mode 100644 api/core/mcp/auth_client_comparison.md create mode 100644 api/tests/unit_tests/core/mcp/__init__.py create mode 100644 api/tests/unit_tests/core/mcp/auth/__init__.py create mode 100644 api/tests/unit_tests/core/mcp/auth/test_auth_flow.py create mode 100644 api/tests/unit_tests/core/mcp/test_auth_client_inheritance.py create mode 100644 api/tests/unit_tests/core/mcp/test_entities.py create mode 100644 api/tests/unit_tests/core/mcp/test_error.py create mode 100644 api/tests/unit_tests/core/mcp/test_mcp_client.py create mode 100644 api/tests/unit_tests/core/mcp/test_types.py create mode 100644 api/tests/unit_tests/core/mcp/test_utils.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index cc50131f0a..870ad87c6c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -6,6 +6,7 @@ from flask_restx import ( Resource, reqparse, ) +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config @@ -15,20 +16,21 @@ from controllers.console.wraps import ( enterprise_license_required, setup_required, ) +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.mcp.auth.auth_flow import auth, handle_callback -from core.mcp.auth.auth_provider import OAuthClientProvider -from core.mcp.error import MCPAuthError, MCPError +from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import CredentialType +from extensions.ext_database import db from libs.helper import StrLen, alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from services.tools.mcp_tools_manage_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_transform_service import ToolTransformService @@ -42,7 +44,9 @@ def is_valid_url(url: str) -> bool: try: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] - except Exception: + except (ValueError, TypeError): + # ValueError: Invalid URL format + # TypeError: url is not a string return False @@ -886,29 +890,34 @@ class ToolProviderMCPApi(Resource): .add_argument("icon_type", type=str, required=True, nullable=False, location="json") .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) - .add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300) + .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) + .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) ) args = parser.parse_args() user, tenant_id = current_account_with_tenant() - if not is_valid_url(args["server_url"]): - raise ValueError("Server URL is not valid.") - return jsonable_encoder( - MCPToolManageService.create_mcp_provider( + + # Parse and validate models + configuration = MCPConfiguration.model_validate(args["configuration"]) + authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + + # Create provider + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + result = service.create_provider( tenant_id=tenant_id, + user_id=user.id, server_url=args["server_url"], name=args["name"], icon=args["icon"], icon_type=args["icon_type"], icon_background=args["icon_background"], - user_id=user.id, server_identifier=args["server_identifier"], - timeout=args["timeout"], - sse_read_timeout=args["sse_read_timeout"], headers=args["headers"], + configuration=configuration, + authentication=authentication, ) - ) + return jsonable_encoder(result) @setup_required @login_required @@ -923,31 +932,43 @@ class ToolProviderMCPApi(Resource): .add_argument("icon_background", type=str, required=False, nullable=True, location="json") .add_argument("provider_id", type=str, required=True, nullable=False, location="json") .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("timeout", type=float, required=False, nullable=True, location="json") - .add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") - .add_argument("headers", type=dict, required=False, nullable=True, location="json") + .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) + .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) + .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) ) args = parser.parse_args() - if not is_valid_url(args["server_url"]): - if "[__HIDDEN__]" in args["server_url"]: - pass - else: - raise ValueError("Server URL is not valid.") + configuration = MCPConfiguration.model_validate(args["configuration"]) + authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() - MCPToolManageService.update_mcp_provider( - tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - timeout=args.get("timeout"), - sse_read_timeout=args.get("sse_read_timeout"), - headers=args.get("headers"), - ) - return {"result": "success"} + + # Step 1: Validate server URL change if needed (includes URL format validation and network operation) + validation_result = None + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + validation_result = service.validate_server_url_change( + tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + ) + + # No need to check for errors here, exceptions will be raised directly + + # Step 2: Perform database update in a transaction + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.update_provider( + tenant_id=current_tenant_id, + provider_id=args["provider_id"], + server_url=args["server_url"], + name=args["name"], + icon=args["icon"], + icon_type=args["icon_type"], + icon_background=args["icon_background"], + server_identifier=args["server_identifier"], + headers=args["headers"], + configuration=configuration, + authentication=authentication, + validation_result=validation_result, + ) + return {"result": "success"} @setup_required @login_required @@ -958,8 +979,11 @@ class ToolProviderMCPApi(Resource): ) args = parser.parse_args() _, current_tenant_id = current_account_with_tenant() - MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - return {"result": "success"} + + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + return {"result": "success"} @console_ns.route("/workspaces/current/tool-provider/mcp/auth") @@ -976,37 +1000,53 @@ class ToolMCPAuthApi(Resource): args = parser.parse_args() provider_id = args["provider_id"] _, tenant_id = current_account_with_tenant() - provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if not provider: - raise ValueError("provider not found") - try: - with MCPClient( - provider.decrypted_server_url, - provider_id, - tenant_id, - authed=False, - authorization_code=args["authorization_code"], - for_list=True, - headers=provider.decrypted_headers, - timeout=provider.timeout, - sse_read_timeout=provider.sse_read_timeout, - ): - MCPToolManageService.update_mcp_provider_credentials( - mcp_provider=provider, - credentials=provider.decrypted_credentials, - authed=True, - ) - return {"result": "success"} - except MCPAuthError: - auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) - return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"]) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) + if not db_provider: + raise ValueError("provider not found") + + # Convert to entity + provider_entity = db_provider.to_entity() + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_authentication() + + # Try to connect without active transaction + try: + # Use MCPClientWithAuthRetry to handle authentication automatically + with MCPClient( + server_url=server_url, + headers=headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ): + # Update credentials in new transaction + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.update_provider_credentials( + provider_id=provider_id, + tenant_id=tenant_id, + credentials=provider_entity.credentials, + authed=True, + ) + return {"result": "success"} + except MCPAuthError as e: + try: + auth_result = auth(provider_entity, args.get("authorization_code")) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + response = service.execute_auth_actions(auth_result) + return response + except MCPRefreshTokenError as e: + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) + raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except MCPError as e: - MCPToolManageService.update_mcp_provider_credentials( - mcp_provider=provider, - credentials={}, - authed=False, - ) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e @@ -1017,8 +1057,10 @@ class ToolMCPDetailApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) - return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) + return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @console_ns.route("/workspaces/current/tools/mcp") @@ -1029,9 +1071,12 @@ class ToolMCPListAllApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + # Skip sensitive data decryption for list view to improve performance + tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False) - return [tool.to_dict() for tool in tools] + return [tool.to_dict() for tool in tools] @console_ns.route("/workspaces/current/tool-provider/mcp/update/") @@ -1041,11 +1086,13 @@ class ToolMCPUpdateApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - tools = MCPToolManageService.list_mcp_tool_from_remote_server( - tenant_id=tenant_id, - provider_id=provider_id, - ) - return jsonable_encoder(tools) + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + tools = service.list_provider_tools( + tenant_id=tenant_id, + provider_id=provider_id, + ) + return jsonable_encoder(tools) @console_ns.route("/mcp/oauth/callback") @@ -1059,5 +1106,15 @@ class ToolMCPCallbackApi(Resource): args = parser.parse_args() state_key = args["state"] authorization_code = args["code"] - handle_callback(state_key, authorization_code) + + # Create service instance for handle_callback + with Session(db.engine) as session, session.begin(): + mcp_service = MCPToolManageService(session=session) + # handle_callback now returns state data and tokens + state_data, tokens = handle_callback(state_key, authorization_code) + # Save tokens using the service layer + mcp_service.save_oauth_data( + state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS + ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py new file mode 100644 index 0000000000..4ac39cef02 --- /dev/null +++ b/api/core/entities/mcp_provider.py @@ -0,0 +1,328 @@ +import json +from datetime import datetime +from enum import StrEnum +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +from pydantic import BaseModel + +from configs import dify_config +from core.entities.provider_entities import BasicProviderConfig +from core.file import helpers as file_helpers +from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.utils.encryption import create_provider_encrypter + +if TYPE_CHECKING: + from models.tools import MCPToolProvider + +# Constants +CLIENT_NAME = "Dify" +CLIENT_URI = "https://github.com/langgenius/dify" +DEFAULT_TOKEN_TYPE = "Bearer" +DEFAULT_EXPIRES_IN = 3600 +MASK_CHAR = "*" +MIN_UNMASK_LENGTH = 6 + + +class MCPSupportGrantType(StrEnum): + """The supported grant types for MCP""" + + AUTHORIZATION_CODE = "authorization_code" + CLIENT_CREDENTIALS = "client_credentials" + REFRESH_TOKEN = "refresh_token" + + +class MCPAuthentication(BaseModel): + client_id: str + client_secret: str | None = None + + +class MCPConfiguration(BaseModel): + timeout: float = 30 + sse_read_timeout: float = 300 + + +class MCPProviderEntity(BaseModel): + """MCP Provider domain entity for business logic operations""" + + # Basic identification + id: str + provider_id: str # server_identifier + name: str + tenant_id: str + user_id: str + + # Server connection info + server_url: str # encrypted URL + headers: dict[str, str] # encrypted headers + timeout: float + sse_read_timeout: float + + # Authentication related + authed: bool + credentials: dict[str, Any] # encrypted credentials + code_verifier: str | None = None # for OAuth + + # Tools and display info + tools: list[dict[str, Any]] # parsed tools list + icon: str | dict[str, str] # parsed icon + + # Timestamps + created_at: datetime + updated_at: datetime + + @classmethod + def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity": + """Create entity from database model with decryption""" + + return cls( + id=db_provider.id, + provider_id=db_provider.server_identifier, + name=db_provider.name, + tenant_id=db_provider.tenant_id, + user_id=db_provider.user_id, + server_url=db_provider.server_url, + headers=db_provider.headers, + timeout=db_provider.timeout, + sse_read_timeout=db_provider.sse_read_timeout, + authed=db_provider.authed, + credentials=db_provider.credentials, + tools=db_provider.tool_dict, + icon=db_provider.icon or "", + created_at=db_provider.created_at, + updated_at=db_provider.updated_at, + ) + + @property + def redirect_url(self) -> str: + """OAuth redirect URL""" + return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback" + + @property + def client_metadata(self) -> OAuthClientMetadata: + """Metadata about this OAuth client.""" + # Get grant type from credentials + credentials = self.decrypt_credentials() + + # Try to get grant_type from different locations + grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE) + + # For nested structure, check if client_information has grant_types + if "client_information" in credentials and isinstance(credentials["client_information"], dict): + client_info = credentials["client_information"] + # If grant_types is specified in client_information, use it to determine grant_type + if "grant_types" in client_info and isinstance(client_info["grant_types"], list): + if "client_credentials" in client_info["grant_types"]: + grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS + elif "authorization_code" in client_info["grant_types"]: + grant_type = MCPSupportGrantType.AUTHORIZATION_CODE + + # Configure based on grant type + is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS + + grant_types = ["refresh_token"] + grant_types.append("client_credentials" if is_client_credentials else "authorization_code") + + response_types = [] if is_client_credentials else ["code"] + redirect_uris = [] if is_client_credentials else [self.redirect_url] + + return OAuthClientMetadata( + redirect_uris=redirect_uris, + token_endpoint_auth_method="none", + grant_types=grant_types, + response_types=response_types, + client_name=CLIENT_NAME, + client_uri=CLIENT_URI, + ) + + @property + def provider_icon(self) -> dict[str, str] | str: + """Get provider icon, handling both dict and string formats""" + if isinstance(self.icon, dict): + return self.icon + try: + return json.loads(self.icon) + except (json.JSONDecodeError, TypeError): + # If not JSON, assume it's a file path + return file_helpers.get_signed_file_url(self.icon) + + def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]: + """Convert to API response format + + Args: + user_name: User name to display + include_sensitive: If False, skip expensive decryption operations (for list view optimization) + """ + response = { + "id": self.id, + "author": user_name or "Anonymous", + "name": self.name, + "icon": self.provider_icon, + "type": ToolProviderType.MCP.value, + "is_team_authorization": self.authed, + "server_url": self.masked_server_url(), + "server_identifier": self.provider_id, + "updated_at": int(self.updated_at.timestamp()), + "label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(), + "description": I18nObject(en_US="", zh_Hans="").to_dict(), + } + + # Add configuration + response["configuration"] = { + "timeout": str(self.timeout), + "sse_read_timeout": str(self.sse_read_timeout), + } + + # Skip expensive operations when sensitive data is not needed (e.g., list view) + if not include_sensitive: + response["masked_headers"] = {} + response["is_dynamic_registration"] = True + else: + # Add masked headers + response["masked_headers"] = self.masked_headers() + + # Add authentication info if available + masked_creds = self.masked_credentials() + if masked_creds: + response["authentication"] = masked_creds + response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get( + "is_dynamic_registration", True + ) + + return response + + def retrieve_client_information(self) -> OAuthClientInformation | None: + """OAuth client information if available""" + credentials = self.decrypt_credentials() + if not credentials: + return None + + # Check if we have nested client_information structure + if "client_information" not in credentials: + return None + client_info_data = credentials["client_information"] + if isinstance(client_info_data, dict): + if "encrypted_client_secret" in client_info_data: + client_info_data["client_secret"] = encrypter.decrypt_token( + self.tenant_id, client_info_data["encrypted_client_secret"] + ) + return OAuthClientInformation.model_validate(client_info_data) + return None + + def retrieve_tokens(self) -> OAuthTokens | None: + """OAuth tokens if available""" + if not self.credentials: + return None + credentials = self.decrypt_credentials() + return OAuthTokens( + access_token=credentials.get("access_token", ""), + token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE), + expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN), + refresh_token=credentials.get("refresh_token", ""), + ) + + def masked_server_url(self) -> str: + """Masked server URL for display""" + parsed = urlparse(self.decrypt_server_url()) + if parsed.path and parsed.path != "/": + masked = parsed._replace(path="/******") + return masked.geturl() + return parsed.geturl() + + def _mask_value(self, value: str) -> str: + """Mask a sensitive value for display""" + if len(value) > MIN_UNMASK_LENGTH: + return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:] + else: + return MASK_CHAR * len(value) + + def masked_headers(self) -> dict[str, str]: + """Masked headers for display""" + return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()} + + def masked_credentials(self) -> dict[str, str]: + """Masked credentials for display""" + credentials = self.decrypt_credentials() + if not credentials: + return {} + + masked = {} + + if "client_information" not in credentials or not isinstance(credentials["client_information"], dict): + return {} + client_info = credentials["client_information"] + # Mask sensitive fields from nested structure + if client_info.get("client_id"): + masked["client_id"] = self._mask_value(client_info["client_id"]) + if client_info.get("encrypted_client_secret"): + masked["client_secret"] = self._mask_value( + encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"]) + ) + if client_info.get("client_secret"): + masked["client_secret"] = self._mask_value(client_info["client_secret"]) + return masked + + def decrypt_server_url(self) -> str: + """Decrypt server URL""" + return encrypter.decrypt_token(self.tenant_id, self.server_url) + + def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]: + """Generic method to decrypt dictionary fields""" + if not data: + return {} + + # Only decrypt fields that are actually encrypted + # For nested structures, client_information is not encrypted as a whole + encrypted_fields = [] + for key, value in data.items(): + # Skip nested objects - they are not encrypted + if isinstance(value, dict): + continue + # Only process string values that might be encrypted + if isinstance(value, str) and value: + encrypted_fields.append(key) + + if not encrypted_fields: + return data + + # Create dynamic config only for encrypted fields + config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields] + + encrypter_instance, _ = create_provider_encrypter( + tenant_id=self.tenant_id, + config=config, + cache=NoOpProviderCredentialCache(), + ) + + # Decrypt only the encrypted fields + decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields}) + + # Merge decrypted data with original data (preserving non-encrypted fields) + result = data.copy() + result.update(decrypted_data) + + return result + + def decrypt_headers(self) -> dict[str, Any]: + """Decrypt headers""" + return self._decrypt_dict(self.headers) + + def decrypt_credentials(self) -> dict[str, Any]: + """Decrypt credentials""" + return self._decrypt_dict(self.credentials) + + def decrypt_authentication(self) -> dict[str, Any]: + """Decrypt authentication""" + # Option 1: if headers is provided, use it and don't need to get token + headers = self.decrypt_headers() + + # Option 2: Add OAuth token if authed and no headers provided + if not self.headers and self.authed: + token = self.retrieve_tokens() + if token: + headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" + return headers diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 7d938a8a7d..951c22f6dd 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -6,11 +6,15 @@ import secrets import urllib.parse from urllib.parse import urljoin, urlparse -import httpx -from pydantic import BaseModel, ValidationError +from httpx import ConnectError, HTTPStatusError, RequestError +from pydantic import ValidationError -from core.mcp.auth.auth_provider import OAuthClientProvider +from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType +from core.helper import ssrf_proxy +from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState +from core.mcp.error import MCPRefreshTokenError from core.mcp.types import ( + LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, @@ -19,21 +23,10 @@ from core.mcp.types import ( ) from extensions.ext_redis import redis_client -LATEST_PROTOCOL_VERSION = "1.0" OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" -class OAuthCallbackState(BaseModel): - provider_id: str - tenant_id: str - server_url: str - metadata: OAuthMetadata | None = None - client_information: OAuthClientInformation - code_verifier: str - redirect_uri: str - - def generate_pkce_challenge() -> tuple[str, str]: """Generate PKCE challenge and verifier.""" code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8") @@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: raise ValueError(f"Invalid state parameter: {str(e)}") -def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState: - """Handle the callback from the OAuth provider.""" +def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]: + """ + Handle the callback from the OAuth provider. + + Returns: + A tuple of (callback_state, tokens) that can be used by the caller to save data. + """ # Retrieve state data from Redis (state is automatically deleted after retrieval) full_state_data = _retrieve_redis_state(state_key) @@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta full_state_data.code_verifier, full_state_data.redirect_uri, ) - provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True) - provider.save_tokens(tokens) - return full_state_data + + return full_state_data, tokens def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: """Check if the server supports OAuth 2.0 Resource Discovery.""" - b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True) - url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" + b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True) + url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource" if b_query: url_for_resource_discovery += f"?{b_query}" if b_fragment: url_for_resource_discovery += f"#{b_fragment}" try: headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} - response = httpx.get(url_for_resource_discovery, headers=headers) + response = ssrf_proxy.get(url_for_resource_discovery, headers=headers) if 200 <= response.status_code < 300: body = response.json() - if "authorization_server_url" in body: + # Support both singular and plural forms + if body.get("authorization_servers"): + return True, body["authorization_servers"][0] + elif body.get("authorization_server_url"): return True, body["authorization_server_url"][0] else: return False, "" return False, "" - except httpx.RequestError: + except RequestError: # Not support resource discovery, fall back to well-known OAuth metadata return False, "" @@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None # First check if the server supports OAuth 2.0 Resource Discovery support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) if support_resource_discovery: - url = oauth_discovery_url + # The oauth_discovery_url is the authorization server base URL + # Try OpenID Connect discovery first (more common), then OAuth 2.0 + urls_to_try = [ + urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"), + urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"), + ] else: - url = urljoin(server_url, "/.well-known/oauth-authorization-server") + urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")] - try: - headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} - response = httpx.get(url, headers=headers) - if response.status_code == 404: - return None - if not response.is_success: - raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") - return OAuthMetadata.model_validate(response.json()) - except httpx.RequestError as e: - if isinstance(e, httpx.ConnectError): - response = httpx.get(url) + headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} + + for url in urls_to_try: + try: + response = ssrf_proxy.get(url, headers=headers) if response.status_code == 404: - return None + continue if not response.is_success: - raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + response.raise_for_status() return OAuthMetadata.model_validate(response.json()) - raise + except (RequestError, HTTPStatusError) as e: + if isinstance(e, ConnectError): + response = ssrf_proxy.get(url) + if response.status_code == 404: + continue # Try next URL + if not response.is_success: + raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + return OAuthMetadata.model_validate(response.json()) + # For other errors, try next URL + continue + + return None # No metadata found def start_authorization( @@ -213,7 +223,7 @@ def exchange_authorization( redirect_uri: str, ) -> OAuthTokens: """Exchanges an authorization code for an access token.""" - grant_type = "authorization_code" + grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value if metadata: token_url = metadata.token_endpoint @@ -233,7 +243,7 @@ def exchange_authorization( if client_information.client_secret: params["client_secret"] = client_information.client_secret - response = httpx.post(token_url, data=params) + response = ssrf_proxy.post(token_url, data=params) if not response.is_success: raise ValueError(f"Token exchange failed: HTTP {response.status_code}") return OAuthTokens.model_validate(response.json()) @@ -246,7 +256,7 @@ def refresh_authorization( refresh_token: str, ) -> OAuthTokens: """Exchange a refresh token for an updated access token.""" - grant_type = "refresh_token" + grant_type = MCPSupportGrantType.REFRESH_TOKEN.value if metadata: token_url = metadata.token_endpoint @@ -263,10 +273,55 @@ def refresh_authorization( if client_information.client_secret: params["client_secret"] = client_information.client_secret - - response = httpx.post(token_url, data=params) + try: + response = ssrf_proxy.post(token_url, data=params) + except ssrf_proxy.MaxRetriesExceededError as e: + raise MCPRefreshTokenError(e) from e if not response.is_success: - raise ValueError(f"Token refresh failed: HTTP {response.status_code}") + raise MCPRefreshTokenError(response.text) + return OAuthTokens.model_validate(response.json()) + + +def client_credentials_flow( + server_url: str, + metadata: OAuthMetadata | None, + client_information: OAuthClientInformation, + scope: str | None = None, +) -> OAuthTokens: + """Execute Client Credentials Flow to get access token.""" + grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value + + if metadata: + token_url = metadata.token_endpoint + if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: + raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") + else: + token_url = urljoin(server_url, "/token") + + # Support both Basic Auth and body parameters for client authentication + headers = {"Content-Type": "application/x-www-form-urlencoded"} + data = {"grant_type": grant_type} + + if scope: + data["scope"] = scope + + # If client_secret is provided, use Basic Auth (preferred method) + if client_information.client_secret: + credentials = f"{client_information.client_id}:{client_information.client_secret}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded_credentials}" + else: + # Fall back to including credentials in the body + data["client_id"] = client_information.client_id + if client_information.client_secret: + data["client_secret"] = client_information.client_secret + + response = ssrf_proxy.post(token_url, headers=headers, data=data) + if not response.is_success: + raise ValueError( + f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}" + ) + return OAuthTokens.model_validate(response.json()) @@ -283,7 +338,7 @@ def register_client( else: registration_url = urljoin(server_url, "/register") - response = httpx.post( + response = ssrf_proxy.post( registration_url, json=client_metadata.model_dump(), headers={"Content-Type": "application/json"}, @@ -294,28 +349,111 @@ def register_client( def auth( - provider: OAuthClientProvider, - server_url: str, + provider: MCPProviderEntity, authorization_code: str | None = None, state_param: str | None = None, - for_list: bool = False, -) -> dict[str, str]: - """Orchestrates the full auth flow with a server using secure Redis state storage.""" - metadata = discover_oauth_metadata(server_url) +) -> AuthResult: + """ + Orchestrates the full auth flow with a server using secure Redis state storage. + + This function performs only network operations and returns actions that need + to be performed by the caller (such as saving data to database). + + Args: + provider: The MCP provider entity + authorization_code: Optional authorization code from OAuth callback + state_param: Optional state parameter from OAuth callback + + Returns: + AuthResult containing actions to be performed and response data + """ + actions: list[AuthAction] = [] + server_url = provider.decrypt_server_url() + server_metadata = discover_oauth_metadata(server_url) + client_metadata = provider.client_metadata + provider_id = provider.id + tenant_id = provider.tenant_id + client_information = provider.retrieve_client_information() + redirect_url = provider.redirect_url + + # Determine grant type based on server metadata + if not server_metadata: + raise ValueError("Failed to discover OAuth metadata from server") + + supported_grant_types = server_metadata.grant_types_supported or [] + + # Convert to lowercase for comparison + supported_grant_types_lower = [gt.lower() for gt in supported_grant_types] + + # Determine which grant type to use + effective_grant_type = None + if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower: + effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value + else: + effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value + + # Get stored credentials + credentials = provider.decrypt_credentials() - # Handle client registration if needed - client_information = provider.client_information() if not client_information: if authorization_code is not None: raise ValueError("Existing OAuth client information is required when exchanging an authorization code") + + # For client credentials flow, we don't need to register client dynamically + if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value: + # Client should provide client_id and client_secret directly + raise ValueError("Client credentials flow requires client_id and client_secret to be provided") + try: - full_information = register_client(server_url, metadata, provider.client_metadata) - except httpx.RequestError as e: + full_information = register_client(server_url, server_metadata, client_metadata) + except RequestError as e: raise ValueError(f"Could not register OAuth client: {e}") - provider.save_client_information(full_information) + + # Return action to save client information + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_CLIENT_INFO, + data={"client_information": full_information.model_dump()}, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + client_information = full_information - # Exchange authorization code for tokens + # Handle client credentials flow + if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value: + # Direct token request without user interaction + try: + scope = credentials.get("scope") + tokens = client_credentials_flow( + server_url, + server_metadata, + client_information, + scope, + ) + + # Return action to save tokens and grant type + token_data = tokens.model_dump() + token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value + + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=token_data, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + + return AuthResult(actions=actions, response={"result": "success"}) + except (RequestError, ValueError, KeyError) as e: + # RequestError: HTTP request failed + # ValueError: Invalid response data + # KeyError: Missing required fields in response + raise ValueError(f"Client credentials flow failed: {e}") + + # Exchange authorization code for tokens (Authorization Code flow) if authorization_code is not None: if not state_param: raise ValueError("State parameter is required when exchanging authorization code") @@ -335,35 +473,69 @@ def auth( tokens = exchange_authorization( server_url, - metadata, + server_metadata, client_information, authorization_code, code_verifier, redirect_uri, ) - provider.save_tokens(tokens) - return {"result": "success"} - provider_tokens = provider.tokens() + # Return action to save tokens + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=tokens.model_dump(), + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + + return AuthResult(actions=actions, response={"result": "success"}) + + provider_tokens = provider.retrieve_tokens() # Handle token refresh or new authorization if provider_tokens and provider_tokens.refresh_token: try: - new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token) - provider.save_tokens(new_tokens) - return {"result": "success"} - except Exception as e: + new_tokens = refresh_authorization( + server_url, server_metadata, client_information, provider_tokens.refresh_token + ) + + # Return action to save new tokens + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=new_tokens.model_dump(), + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + + return AuthResult(actions=actions, response={"result": "success"}) + except (RequestError, ValueError, KeyError) as e: + # RequestError: HTTP request failed + # ValueError: Invalid response data + # KeyError: Missing required fields in response raise ValueError(f"Could not refresh OAuth tokens: {e}") - # Start new authorization flow + # Start new authorization flow (only for authorization code flow) authorization_url, code_verifier = start_authorization( server_url, - metadata, + server_metadata, client_information, - provider.redirect_url, - provider.mcp_provider.id, - provider.mcp_provider.tenant_id, + redirect_url, + provider_id, + tenant_id, ) - provider.save_code_verifier(code_verifier) - return {"authorization_url": authorization_url} + # Return action to save code verifier + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_CODE_VERIFIER, + data={"code_verifier": code_verifier}, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + + return AuthResult(actions=actions, response={"authorization_url": authorization_url}) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py deleted file mode 100644 index 3a550eb1b6..0000000000 --- a/api/core/mcp/auth/auth_provider.py +++ /dev/null @@ -1,77 +0,0 @@ -from configs import dify_config -from core.mcp.types import ( - OAuthClientInformation, - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthTokens, -) -from models.tools import MCPToolProvider -from services.tools.mcp_tools_manage_service import MCPToolManageService - - -class OAuthClientProvider: - mcp_provider: MCPToolProvider - - def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False): - if for_list: - self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) - else: - self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id) - - @property - def redirect_url(self) -> str: - """The URL to redirect the user agent to after authorization.""" - return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback" - - @property - def client_metadata(self) -> OAuthClientMetadata: - """Metadata about this OAuth client.""" - return OAuthClientMetadata( - redirect_uris=[self.redirect_url], - token_endpoint_auth_method="none", - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - client_name="Dify", - client_uri="https://github.com/langgenius/dify", - ) - - def client_information(self) -> OAuthClientInformation | None: - """Loads information about this OAuth client.""" - client_information = self.mcp_provider.decrypted_credentials.get("client_information", {}) - if not client_information: - return None - return OAuthClientInformation.model_validate(client_information) - - def save_client_information(self, client_information: OAuthClientInformationFull): - """Saves client information after dynamic registration.""" - MCPToolManageService.update_mcp_provider_credentials( - self.mcp_provider, - {"client_information": client_information.model_dump()}, - ) - - def tokens(self) -> OAuthTokens | None: - """Loads any existing OAuth tokens for the current session.""" - credentials = self.mcp_provider.decrypted_credentials - if not credentials: - return None - return OAuthTokens( - access_token=credentials.get("access_token", ""), - token_type=credentials.get("token_type", "Bearer"), - expires_in=int(credentials.get("expires_in", "3600") or 3600), - refresh_token=credentials.get("refresh_token", ""), - ) - - def save_tokens(self, tokens: OAuthTokens): - """Stores new OAuth tokens for the current session.""" - # update mcp provider credentials - token_dict = tokens.model_dump() - MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) - - def save_code_verifier(self, code_verifier: str): - """Saves a PKCE code verifier for the current session.""" - MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) - - def code_verifier(self) -> str: - """Loads the PKCE code verifier for the current session.""" - # get code verifier from mcp provider credentials - return str(self.mcp_provider.decrypted_credentials.get("code_verifier", "")) diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py new file mode 100644 index 0000000000..942c8d3c23 --- /dev/null +++ b/api/core/mcp/auth_client.py @@ -0,0 +1,191 @@ +""" +MCP Client with Authentication Retry Support + +This module provides an enhanced MCPClient that automatically handles +authentication failures and retries operations after refreshing tokens. +""" + +import logging +from collections.abc import Callable +from typing import Any + +from sqlalchemy.orm import Session + +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.error import MCPAuthError +from core.mcp.mcp_client import MCPClient +from core.mcp.types import CallToolResult, Tool +from extensions.ext_database import db + +logger = logging.getLogger(__name__) + + +class MCPClientWithAuthRetry(MCPClient): + """ + An enhanced MCPClient that provides automatic authentication retry. + + This class extends MCPClient and intercepts MCPAuthError exceptions + to refresh authentication before retrying failed operations. + + Note: This class uses lazy session creation - database sessions are only + created when authentication retry is actually needed, not on every request. + """ + + def __init__( + self, + server_url: str, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, + provider_entity: MCPProviderEntity | None = None, + authorization_code: str | None = None, + by_server_id: bool = False, + ): + """ + Initialize the MCP client with auth retry capability. + + Args: + server_url: The MCP server URL + headers: Optional headers for requests + timeout: Request timeout + sse_read_timeout: SSE read timeout + provider_entity: Provider entity for authentication + authorization_code: Optional authorization code for initial auth + by_server_id: Whether to look up provider by server ID + """ + super().__init__(server_url, headers, timeout, sse_read_timeout) + + self.provider_entity = provider_entity + self.authorization_code = authorization_code + self.by_server_id = by_server_id + self._has_retried = False + + def _handle_auth_error(self, error: MCPAuthError) -> None: + """ + Handle authentication error by refreshing tokens. + + This method creates a short-lived database session only when authentication + retry is needed, minimizing database connection hold time. + + Args: + error: The authentication error + + Raises: + MCPAuthError: If authentication fails or max retries reached + """ + if not self.provider_entity: + raise error + if self._has_retried: + raise error + + self._has_retried = True + + try: + # Create a temporary session only for auth retry + # This session is short-lived and only exists during the auth operation + + from services.tools.mcp_tools_manage_service import MCPToolManageService + + with Session(db.engine) as session, session.begin(): + mcp_service = MCPToolManageService(session=session) + + # Perform authentication using the service's auth method + mcp_service.auth_with_actions(self.provider_entity, self.authorization_code) + + # Retrieve new tokens + self.provider_entity = mcp_service.get_provider_entity( + self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id + ) + + # Session is closed here, before we update headers + token = self.provider_entity.retrieve_tokens() + if not token: + raise MCPAuthError("Authentication failed - no token received") + + # Update headers with new token + self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" + + # Clear authorization code after first use + self.authorization_code = None + + except MCPAuthError: + # Re-raise MCPAuthError as is + raise + except Exception as e: + # Catch all exceptions during auth retry + logger.exception("Authentication retry failed") + raise MCPAuthError(f"Authentication retry failed: {e}") from e + + def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any: + """ + Execute a function with authentication retry logic. + + Args: + func: The function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + The result of the function call + + Raises: + MCPAuthError: If authentication fails after retries + Any other exceptions from the function + """ + try: + return func(*args, **kwargs) + except MCPAuthError as e: + self._handle_auth_error(e) + + # Re-initialize the connection with new headers + if self._initialized: + # Clean up existing connection + self._exit_stack.close() + self._session = None + self._initialized = False + + # Re-initialize with new headers + self._initialize() + self._initialized = True + + return func(*args, **kwargs) + finally: + # Reset retry flag after operation completes + self._has_retried = False + + def __enter__(self): + """Enter the context manager with retry support.""" + + def initialize_with_retry(): + super(MCPClientWithAuthRetry, self).__enter__() + return self + + return self._execute_with_retry(initialize_with_retry) + + def list_tools(self) -> list[Tool]: + """ + List available tools from the MCP server with auth retry. + + Returns: + List of available tools + + Raises: + MCPAuthError: If authentication fails after retries + """ + return self._execute_with_retry(super().list_tools) + + def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: + """ + Invoke a tool on the MCP server with auth retry. + + Args: + tool_name: Name of the tool to invoke + tool_args: Arguments for the tool + + Returns: + Result of the tool invocation + + Raises: + MCPAuthError: If authentication fails after retries + """ + return self._execute_with_retry(super().invoke_tool, tool_name, tool_args) diff --git a/api/core/mcp/auth_client_comparison.md b/api/core/mcp/auth_client_comparison.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 6db22a09e0..2d5e3dd263 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -46,7 +46,7 @@ class SSETransport: url: str, headers: dict[str, Any] | None = None, timeout: float = 5.0, - sse_read_timeout: float = 5 * 60, + sse_read_timeout: float = 1 * 60, ): """Initialize the SSE transport. @@ -255,7 +255,7 @@ def sse_client( url: str, headers: dict[str, Any] | None = None, timeout: float = 5.0, - sse_read_timeout: float = 5 * 60, + sse_read_timeout: float = 1 * 60, ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]: """ Client transport for SSE. @@ -276,31 +276,34 @@ def sse_client( read_queue: ReadQueue | None = None write_queue: WriteQueue | None = None - with ThreadPoolExecutor() as executor: - try: - with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: - with ssrf_proxy_sse_connect( - url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client - ) as event_source: - event_source.response.raise_for_status() + executor = ThreadPoolExecutor() + try: + with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: + with ssrf_proxy_sse_connect( + url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client + ) as event_source: + event_source.response.raise_for_status() - read_queue, write_queue = transport.connect(executor, client, event_source) + read_queue, write_queue = transport.connect(executor, client, event_source) - yield read_queue, write_queue + yield read_queue, write_queue - except httpx.HTTPStatusError as exc: - if exc.response.status_code == 401: - raise MCPAuthError() - raise MCPConnectionError() - except Exception: - logger.exception("Error connecting to SSE endpoint") - raise - finally: - # Clean up queues - if read_queue: - read_queue.put(None) - if write_queue: - write_queue.put(None) + except httpx.HTTPStatusError as exc: + if exc.response.status_code == 401: + raise MCPAuthError() + raise MCPConnectionError() + except Exception: + logger.exception("Error connecting to SSE endpoint") + raise + finally: + # Clean up queues + if read_queue: + read_queue.put(None) + if write_queue: + write_queue.put(None) + + # Shutdown executor without waiting to prevent hanging + executor.shutdown(wait=False) def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage): diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 7eafa79837..a9c1eda624 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -434,45 +434,48 @@ def streamablehttp_client( server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server - with ThreadPoolExecutor(max_workers=2) as executor: - try: - with create_ssrf_proxy_mcp_http_client( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - ) as client: - # Define callbacks that need access to thread pool - def start_get_stream(): - """Start a worker thread to handle server-initiated messages.""" - executor.submit(transport.handle_get_stream, client, server_to_client_queue) + executor = ThreadPoolExecutor(max_workers=2) + try: + with create_ssrf_proxy_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), + ) as client: + # Define callbacks that need access to thread pool + def start_get_stream(): + """Start a worker thread to handle server-initiated messages.""" + executor.submit(transport.handle_get_stream, client, server_to_client_queue) - # Start the post_writer worker thread - executor.submit( - transport.post_writer, - client, - client_to_server_queue, # Queue for messages FROM client TO server - server_to_client_queue, # Queue for messages FROM server TO client - start_get_stream, - ) + # Start the post_writer worker thread + executor.submit( + transport.post_writer, + client, + client_to_server_queue, # Queue for messages FROM client TO server + server_to_client_queue, # Queue for messages FROM server TO client + start_get_stream, + ) - try: - yield ( - server_to_client_queue, # Queue for receiving messages FROM server - client_to_server_queue, # Queue for sending messages TO server - transport.get_session_id, - ) - finally: - if transport.session_id and terminate_on_close: - transport.terminate_session(client) - - # Signal threads to stop - client_to_server_queue.put(None) - finally: - # Clear any remaining items and add None sentinel to unblock any waiting threads try: - while not client_to_server_queue.empty(): - client_to_server_queue.get_nowait() - except queue.Empty: - pass + yield ( + server_to_client_queue, # Queue for receiving messages FROM server + client_to_server_queue, # Queue for sending messages TO server + transport.get_session_id, + ) + finally: + if transport.session_id and terminate_on_close: + transport.terminate_session(client) - client_to_server_queue.put(None) - server_to_client_queue.put(None) + # Signal threads to stop + client_to_server_queue.put(None) + finally: + # Clear any remaining items and add None sentinel to unblock any waiting threads + try: + while not client_to_server_queue.empty(): + client_to_server_queue.get_nowait() + except queue.Empty: + pass + + client_to_server_queue.put(None) + server_to_client_queue.put(None) + + # Shutdown executor without waiting to prevent hanging + executor.shutdown(wait=False) diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py index 7553c10a2e..08823daab1 100644 --- a/api/core/mcp/entities.py +++ b/api/core/mcp/entities.py @@ -1,10 +1,13 @@ from dataclasses import dataclass +from enum import StrEnum from typing import Any, Generic, TypeVar -from core.mcp.session.base_session import BaseSession -from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams +from pydantic import BaseModel -SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION] +from core.mcp.session.base_session import BaseSession +from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams + +SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) @@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]): meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + + +class AuthActionType(StrEnum): + """Types of actions that can be performed during auth flow.""" + + SAVE_CLIENT_INFO = "save_client_info" + SAVE_TOKENS = "save_tokens" + SAVE_CODE_VERIFIER = "save_code_verifier" + START_AUTHORIZATION = "start_authorization" + SUCCESS = "success" + + +class AuthAction(BaseModel): + """Represents an action that needs to be performed as a result of auth flow.""" + + action_type: AuthActionType + data: dict[str, Any] + provider_id: str | None = None + tenant_id: str | None = None + + +class AuthResult(BaseModel): + """Result of auth function containing actions to be performed and response data.""" + + actions: list[AuthAction] + response: dict[str, str] + + +class OAuthCallbackState(BaseModel): + """State data stored in Redis during OAuth callback flow.""" + + provider_id: str + tenant_id: str + server_url: str + metadata: OAuthMetadata | None = None + client_information: OAuthClientInformation + code_verifier: str + redirect_uri: str diff --git a/api/core/mcp/error.py b/api/core/mcp/error.py index 92ea7bde09..d4fb8b7674 100644 --- a/api/core/mcp/error.py +++ b/api/core/mcp/error.py @@ -8,3 +8,7 @@ class MCPConnectionError(MCPError): class MCPAuthError(MCPConnectionError): pass + + +class MCPRefreshTokenError(MCPError): + pass diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 86ec2c4db9..b0e0dab9be 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -7,9 +7,9 @@ from urllib.parse import urlparse from core.mcp.client.sse_client import sse_client from core.mcp.client.streamable_client import streamablehttp_client -from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.error import MCPConnectionError from core.mcp.session.client_session import ClientSession -from core.mcp.types import Tool +from core.mcp.types import CallToolResult, Tool logger = logging.getLogger(__name__) @@ -18,40 +18,18 @@ class MCPClient: def __init__( self, server_url: str, - provider_id: str, - tenant_id: str, - authed: bool = True, - authorization_code: str | None = None, - for_list: bool = False, headers: dict[str, str] | None = None, timeout: float | None = None, sse_read_timeout: float | None = None, ): - # Initialize info - self.provider_id = provider_id - self.tenant_id = tenant_id - self.client_type = "streamable" self.server_url = server_url self.headers = headers or {} self.timeout = timeout self.sse_read_timeout = sse_read_timeout - # Authentication info - self.authed = authed - self.authorization_code = authorization_code - if authed: - from core.mcp.auth.auth_provider import OAuthClientProvider - - self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list) - self.token = self.provider.tokens() - # Initialize session and client objects self._session: ClientSession | None = None - self._streams_context: AbstractContextManager[Any] | None = None - self._session_context: ClientSession | None = None self._exit_stack = ExitStack() - - # Whether the client has been initialized self._initialized = False def __enter__(self): @@ -85,61 +63,42 @@ class MCPClient: logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") - def connect_server( - self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True - ): - from core.mcp.auth.auth_flow import auth + def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None: + """ + Connect to the MCP server using streamable http or sse. + Default to streamable http. + Args: + client_factory: The client factory to use(streamablehttp_client or sse_client). + method_name: The method name to use(mcp or sse). + """ + streams_context = client_factory( + url=self.server_url, + headers=self.headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + ) - try: - headers = ( - {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} - if self.authed and self.token - else self.headers - ) - self._streams_context = client_factory( - url=self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - ) - if not self._streams_context: - raise MCPConnectionError("Failed to create connection context") + # Use exit_stack to manage context managers properly + if method_name == "mcp": + read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context) + streams = (read_stream, write_stream) + else: # sse_client + streams = self._exit_stack.enter_context(streams_context) - # Use exit_stack to manage context managers properly - if method_name == "mcp": - read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context) - streams = (read_stream, write_stream) - else: # sse_client - streams = self._exit_stack.enter_context(self._streams_context) - - self._session_context = ClientSession(*streams) - self._session = self._exit_stack.enter_context(self._session_context) - self._session.initialize() - return - - except MCPAuthError: - if not self.authed: - raise - try: - auth(self.provider, self.server_url, self.authorization_code) - except Exception as e: - raise ValueError(f"Failed to authenticate: {e}") - self.token = self.provider.tokens() - if first_try: - return self.connect_server(client_factory, method_name, first_try=False) + session_context = ClientSession(*streams) + self._session = self._exit_stack.enter_context(session_context) + self._session.initialize() def list_tools(self) -> list[Tool]: - """Connect to an MCP server running with SSE transport""" - # List available tools to verify connection - if not self._initialized or not self._session: + """List available tools from the MCP server""" + if not self._session: raise ValueError("Session not initialized.") response = self._session.list_tools() - tools = response.tools - return tools + return response.tools - def invoke_tool(self, tool_name: str, tool_args: dict): + def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: """Call a tool""" - if not self._initialized or not self._session: + if not self._session: raise ValueError("Session not initialized.") return self._session.call_tool(tool_name, tool_args) @@ -153,6 +112,4 @@ class MCPClient: raise ValueError(f"Error during cleanup: {e}") finally: self._session = None - self._session_context = None - self._streams_context = None self._initialized = False diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 653b3773c0..3dcd166ea2 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -201,11 +201,14 @@ class BaseSession( self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds except TimeoutError: # If the receiver loop is still running after timeout, we'll force shutdown - pass + # Cancel the future to interrupt the receiver loop + self._receiver_future.cancel() # Shutdown the executor if self._executor: - self._executor.shutdown(wait=True) + # Use non-blocking shutdown to prevent hanging + # The receiver thread should have already exited due to the None message in the queue + self._executor.shutdown(wait=False) def send_request( self, diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index fa1d309134..77d35cca19 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -284,7 +284,7 @@ class ClientSession( def complete( self, - ref: types.ResourceReference | types.PromptReference, + ref: types.ResourceTemplateReference | types.PromptReference, argument: dict[str, str], ) -> types.CompleteResult: """Send a completion/complete request.""" diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index c7a046b585..fd2062d2e1 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1,13 +1,6 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import ( - Annotated, - Any, - Generic, - Literal, - TypeAlias, - TypeVar, -) +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints @@ -33,6 +26,7 @@ for reference. LATEST_PROTOCOL_VERSION = "2025-03-26" # Server support 2024-11-05 to allow claude to use. SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05" +DEFAULT_NEGOTIATED_VERSION = "2025-03-26" ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] @@ -55,14 +49,22 @@ class RequestParams(BaseModel): meta: Meta | None = Field(alias="_meta", default=None) +class PaginatedRequestParams(RequestParams): + cursor: Cursor | None = None + """ + An opaque token representing the current pagination position. + If provided, the server should return results starting after this cursor. + """ + + class NotificationParams(BaseModel): class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) """ - This parameter name is reserved by MCP to allow clients and servers to attach - additional metadata to their notifications. + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. """ @@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): model_config = ConfigDict(extra="allow") -class PaginatedRequest(Request[RequestParamsT, MethodT]): - cursor: Cursor | None = None - """ - An opaque token representing the current pagination position. - If provided, the server should return results starting after this cursor. - """ +class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): + """Base class for paginated requests, + matching the schema's PaginatedRequest interface.""" + + params: PaginatedRequestParams | None = None class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): @@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): class Result(BaseModel): """Base class for JSON-RPC results.""" - model_config = ConfigDict(extra="allow") - meta: dict[str, Any] | None = Field(alias="_meta", default=None) """ - This result property is reserved by the protocol to allow clients and servers to - attach additional metadata to their responses. + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. """ + model_config = ConfigDict(extra="allow") class PaginatedResult(Result): @@ -186,10 +186,26 @@ class EmptyResult(Result): """A response that indicates success but carries no data.""" -class Implementation(BaseModel): - """Describes the name and version of an MCP implementation.""" +class BaseMetadata(BaseModel): + """Base class for entities with name and optional title fields.""" name: str + """The programmatic name of the entity.""" + + title: str | None = None + """ + Intended for UI and end-user contexts — optimized to be human-readable and easily understood, + even by those unfamiliar with domain-specific terminology. + + If not provided, the name should be used for display (except for Tool, + where `annotations.title` should be given precedence over using `name`, + if present). + """ + + +class Implementation(BaseMetadata): + """Describes the name and version of an MCP implementation.""" + version: str model_config = ConfigDict(extra="allow") @@ -203,7 +219,7 @@ class RootsCapability(BaseModel): class SamplingCapability(BaseModel): - """Capability for logging operations.""" + """Capability for sampling operations.""" model_config = ConfigDict(extra="allow") @@ -252,6 +268,12 @@ class LoggingCapability(BaseModel): model_config = ConfigDict(extra="allow") +class CompletionsCapability(BaseModel): + """Capability for completions operations.""" + + model_config = ConfigDict(extra="allow") + + class ServerCapabilities(BaseModel): """Capabilities that a server may support.""" @@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel): """Present if the server offers any resources to read.""" tools: ToolsCapability | None = None """Present if the server offers any tools to call.""" + completions: CompletionsCapability | None = None + """Present if the server offers autocompletion suggestions for prompts and resources.""" model_config = ConfigDict(extra="allow") @@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]) to begin initialization. """ - method: Literal["initialize"] + method: Literal["initialize"] = "initialize" params: InitializeRequestParams @@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n finished. """ - method: Literal["notifications/initialized"] + method: Literal["notifications/initialized"] = "notifications/initialized" params: NotificationParams | None = None @@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]): still alive. """ - method: Literal["ping"] + method: Literal["ping"] = "ping" params: RequestParams | None = None @@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams): """ total: float | None = None """Total number of items to process (or total progress required), if known.""" + message: str | None = None + """ + Message related to progress. This should provide relevant human readable + progress information. + """ model_config = ConfigDict(extra="allow") @@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not long-running request. """ - method: Literal["notifications/progress"] + method: Literal["notifications/progress"] = "notifications/progress" params: ProgressNotificationParams -class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]): +class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): """Sent from the client to request a list of resources the server has.""" - method: Literal["resources/list"] - params: RequestParams | None = None + method: Literal["resources/list"] = "resources/list" class Annotations(BaseModel): @@ -360,13 +388,11 @@ class Annotations(BaseModel): model_config = ConfigDict(extra="allow") -class Resource(BaseModel): +class Resource(BaseMetadata): """A known resource that the server is capable of reading.""" uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] """The URI of this resource.""" - name: str - """A human-readable name for this resource.""" description: str | None = None """A description of what this resource represents.""" mimeType: str | None = None @@ -379,10 +405,15 @@ class Resource(BaseModel): This can be used by Hosts to display file sizes and estimate context window usage. """ annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") -class ResourceTemplate(BaseModel): +class ResourceTemplate(BaseMetadata): """A template description for resources available on the server.""" uriTemplate: str @@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel): A URI template (according to RFC 6570) that can be used to construct resource URIs. """ - name: str - """A human-readable name for the type of resource this template refers to.""" description: str | None = None """A human-readable description of what this template is for.""" mimeType: str | None = None @@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel): included if all resources matching this template have the same type. """ annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult): resources: list[Resource] -class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]): +class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): """Sent from the client to request a list of resource templates the server has.""" - method: Literal["resources/templates/list"] - params: RequestParams | None = None + method: Literal["resources/templates/list"] = "resources/templates/list" class ListResourceTemplatesResult(PaginatedResult): @@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams): class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): """Sent from the client to the server, to read a specific resource URI.""" - method: Literal["resources/read"] + method: Literal["resources/read"] = "resources/read" params: ReadResourceRequestParams @@ -447,6 +480,11 @@ class ResourceContents(BaseModel): """The URI of this resource.""" mimeType: str | None = None """The MIME type of this resource, if known.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -481,7 +519,7 @@ class ResourceListChangedNotification( of resources it can read from has changed. """ - method: Literal["notifications/resources/list_changed"] + method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed" params: NotificationParams | None = None @@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr whenever a particular resource changes. """ - method: Literal["resources/subscribe"] + method: Literal["resources/subscribe"] = "resources/subscribe" params: SubscribeRequestParams @@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un the server. """ - method: Literal["resources/unsubscribe"] + method: Literal["resources/unsubscribe"] = "resources/unsubscribe" params: UnsubscribeRequestParams @@ -543,15 +581,14 @@ class ResourceUpdatedNotification( changed and may need to be read again. """ - method: Literal["notifications/resources/updated"] + method: Literal["notifications/resources/updated"] = "notifications/resources/updated" params: ResourceUpdatedNotificationParams -class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]): +class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): """Sent from the client to request a list of prompts and prompt templates.""" - method: Literal["prompts/list"] - params: RequestParams | None = None + method: Literal["prompts/list"] = "prompts/list" class PromptArgument(BaseModel): @@ -566,15 +603,18 @@ class PromptArgument(BaseModel): model_config = ConfigDict(extra="allow") -class Prompt(BaseModel): +class Prompt(BaseMetadata): """A prompt or prompt template that the server offers.""" - name: str - """The name of the prompt or prompt template.""" description: str | None = None """An optional description of what this prompt provides.""" arguments: list[PromptArgument] | None = None """A list of arguments to use for templating the prompt.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams): class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): """Used by the client to get a prompt provided by the server.""" - method: Literal["prompts/get"] + method: Literal["prompts/get"] = "prompts/get" params: GetPromptRequestParams @@ -608,6 +648,11 @@ class TextContent(BaseModel): text: str """The text content of the message.""" annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -623,6 +668,31 @@ class ImageContent(BaseModel): image types. """ annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class AudioContent(BaseModel): + """Audio content for a message.""" + + type: Literal["audio"] + data: str + """The base64-encoded audio data.""" + mimeType: str + """ + The MIME type of the audio. Different providers may support different + audio types. + """ + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -630,7 +700,7 @@ class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model_config = ConfigDict(extra="allow") @@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel): type: Literal["resource"] resource: TextResourceContents | BlobResourceContents annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") +class ResourceLink(Resource): + """ + A resource that the server is capable of reading, included in a prompt or tool call result. + + Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. + """ + + type: Literal["resource_link"] + + +ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource +"""A content block that can be used in prompts and tool results.""" + +Content: TypeAlias = ContentBlock +# """DEPRECATED: Content is deprecated, you should use ContentBlock directly.""" + + class PromptMessage(BaseModel): """Describes a message returned as part of a prompt.""" role: Role - content: TextContent | ImageContent | EmbeddedResource + content: ContentBlock model_config = ConfigDict(extra="allow") @@ -672,15 +764,14 @@ class PromptListChangedNotification( of prompts it offers has changed. """ - method: Literal["notifications/prompts/list_changed"] + method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed" params: NotificationParams | None = None -class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]): +class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): """Sent from the client to request a list of tools the server has.""" - method: Literal["tools/list"] - params: RequestParams | None = None + method: Literal["tools/list"] = "tools/list" class ToolAnnotations(BaseModel): @@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel): model_config = ConfigDict(extra="allow") -class Tool(BaseModel): +class Tool(BaseMetadata): """Definition for a tool the client can call.""" - name: str - """The name of the tool.""" description: str | None = None """A human-readable description of the tool.""" inputSchema: dict[str, Any] """A JSON Schema object defining the expected parameters for the tool.""" + outputSchema: dict[str, Any] | None = None + """ + An optional JSON Schema object defining the structure of the tool's output + returned in the structuredContent field of a CallToolResult. + """ annotations: ToolAnnotations | None = None """Optional additional tool information.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams): class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): """Used by the client to invoke a tool provided by the server.""" - method: Literal["tools/call"] + method: Literal["tools/call"] = "tools/call" params: CallToolRequestParams class CallToolResult(Result): """The server's response to a tool call.""" - content: list[TextContent | ImageContent | EmbeddedResource] + content: list[ContentBlock] + structuredContent: dict[str, Any] | None = None + """An optional JSON object that represents the structured result of the tool call.""" isError: bool = False @@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera of tools it offers has changed. """ - method: Literal["notifications/tools/list_changed"] + method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed" params: NotificationParams | None = None @@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams): class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): """A request from the client to the server, to enable or adjust logging.""" - method: Literal["logging/setLevel"] + method: Literal["logging/setLevel"] = "logging/setLevel" params: SetLevelRequestParams @@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams): """The severity of this log message.""" logger: str | None = None """An optional name of the logger issuing this message.""" - data: Any = None + data: Any """ The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. @@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams): class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): """Notification of a log message passed from server to client.""" - method: Literal["notifications/message"] + method: Literal["notifications/message"] = "notifications/message" params: LoggingMessageNotificationParams @@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams): class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): """A request from the server to sample an LLM via the client.""" - method: Literal["sampling/createMessage"] + method: Literal["sampling/createMessage"] = "sampling/createMessage" params: CreateMessageRequestParams @@ -925,14 +1026,14 @@ class CreateMessageResult(Result): """The client's response to a sampling/create_message request from the server.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model: str """The name of the model that generated the message.""" stopReason: StopReason | None = None """The reason why sampling stopped, if known.""" -class ResourceReference(BaseModel): +class ResourceTemplateReference(BaseModel): """A reference to a resource or resource template definition.""" type: Literal["ref/resource"] @@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel): model_config = ConfigDict(extra="allow") +class CompletionContext(BaseModel): + """Additional, optional context for completions.""" + + arguments: dict[str, str] | None = None + """Previously-resolved variables in a URI template or prompt.""" + model_config = ConfigDict(extra="allow") + + class CompleteRequestParams(RequestParams): """Parameters for completion requests.""" - ref: ResourceReference | PromptReference + ref: ResourceTemplateReference | PromptReference argument: CompletionArgument + context: CompletionContext | None = None + """Additional, optional context for completions""" model_config = ConfigDict(extra="allow") class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): """A request from the client to the server, to ask for completion options.""" - method: Literal["completion/complete"] + method: Literal["completion/complete"] = "completion/complete" params: CompleteRequestParams @@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): structure or access specific locations that the client has permission to read from. """ - method: Literal["roots/list"] + method: Literal["roots/list"] = "roots/list" params: RequestParams | None = None @@ -1029,6 +1140,11 @@ class Root(BaseModel): identifier for the root, which may be useful for display purposes or for referencing the root in other parts of the application. """ + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -1054,7 +1170,7 @@ class RootsListChangedNotification( using the ListRootsRequest. """ - method: Literal["notifications/roots/list_changed"] + method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed" params: NotificationParams | None = None @@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n previously-issued request. """ - method: Literal["notifications/cancelled"] + method: Literal["notifications/cancelled"] = "notifications/cancelled" params: CancelledNotificationParams diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 6e0462c530..82616596f8 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -217,3 +217,16 @@ class Tool(ABC): return ToolInvokeMessage( type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object) ) + + def create_variable_message( + self, variable_name: str, variable_value: Any, stream: bool = False + ) -> ToolInvokeMessage: + """ + create a variable message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.VARIABLE, + message=ToolInvokeMessage.VariableMessage( + variable_name=variable_name, variable_value=variable_value, stream=stream + ), + ) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index de6bf01ae9..8f7d1101cb 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject @@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel): server_url: str | None = Field(default="", description="The server url of the tool") updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool") - timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool") - sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") + masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") + authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool") + is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered") + configuration: MCPConfiguration | None = Field( + default=None, description="The timeout and sse_read_timeout of the MCP tool" + ) @field_validator("tools", mode="before") @classmethod @@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel): if self.type == ToolProviderType.MCP: optional_fields.update(self.optional_field("updated_at", self.updated_at)) optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) - optional_fields.update(self.optional_field("timeout", self.timeout)) - optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout)) + optional_fields.update( + self.optional_field( + "configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration() + ) + ) + optional_fields.update( + self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None) + ) + optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) optional_fields.update(self.optional_field("original_headers", self.original_headers)) return { diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index f0e4dba9c3..557211c8c8 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,6 +1,6 @@ -import json from typing import Any, Self +from core.entities.mcp_provider import MCPProviderEntity from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime @@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController): """ from db provider """ - tools = [] - tools_data = json.loads(db_provider.tools) - remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data] - user = db_provider.load_user() + # Convert to entity first + provider_entity = db_provider.to_entity() + return cls.from_entity(provider_entity) + + @classmethod + def from_entity(cls, entity: MCPProviderEntity) -> Self: + """ + create a MCPToolProviderController from a MCPProviderEntity + """ + remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools] + tools = [ ToolEntity( identity=ToolIdentity( - author=user.name if user else "Anonymous", + author="Anonymous", # Tool level author is not stored name=remote_mcp_tool.name, label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name), - provider=db_provider.server_identifier, - icon=db_provider.icon, + provider=entity.provider_id, + icon=entity.icon if isinstance(entity.icon, str) else "", ), parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema), description=ToolDescription( @@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController): ), llm=remote_mcp_tool.description or "", ), + output_schema=remote_mcp_tool.outputSchema or {}, has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0, ) for remote_mcp_tool in remote_mcp_tools ] - if not db_provider.icon: + if not entity.icon: raise ValueError("Database provider icon is required") return cls( entity=ToolProviderEntityWithPlugin( identity=ToolProviderIdentity( - author=user.name if user else "Anonymous", - name=db_provider.name, - label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), + author="Anonymous", # Provider level author is not stored in entity + name=entity.name, + label=I18nObject(en_US=entity.name, zh_Hans=entity.name), description=I18nObject(en_US="", zh_Hans=""), - icon=db_provider.icon, + icon=entity.icon if isinstance(entity.icon, str) else "", ), plugin_id=None, credentials_schema=[], tools=tools, ), - provider_id=db_provider.server_identifier or "", - tenant_id=db_provider.tenant_id or "", - server_url=db_provider.decrypted_server_url, - headers=db_provider.decrypted_headers or {}, - timeout=db_provider.timeout, - sse_read_timeout=db_provider.sse_read_timeout, + provider_id=entity.provider_id, + tenant_id=entity.tenant_id, + server_url=entity.server_url, + headers=entity.headers, + timeout=entity.timeout, + sse_read_timeout=entity.sse_read_timeout, ) def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 976d4dc942..a476859f29 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -3,12 +3,13 @@ import json from collections.abc import Generator from typing import Any -from core.mcp.error import MCPAuthError, MCPConnectionError -from core.mcp.mcp_client import MCPClient -from core.mcp.types import ImageContent, TextContent +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPConnectionError +from core.mcp.types import CallToolResult, ImageContent, TextContent from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType +from core.tools.errors import ToolInvokeError class MCPTool(Tool): @@ -44,40 +45,32 @@ class MCPTool(Tool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: - from core.tools.errors import ToolInvokeError - - try: - with MCPClient( - self.server_url, - self.provider_id, - self.tenant_id, - authed=True, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - ) as mcp_client: - tool_parameters = self._handle_none_parameter(tool_parameters) - result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - except MCPAuthError as e: - raise ToolInvokeError("Please auth the tool first") from e - except MCPConnectionError as e: - raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e - except Exception as e: - raise ToolInvokeError(f"Failed to invoke tool: {e}") from e - + result = self.invoke_remote_mcp_tool(tool_parameters) + # handle dify tool output for content in result.content: if isinstance(content, TextContent): yield from self._process_text_content(content) elif isinstance(content, ImageContent): yield self._process_image_content(content) + # handle MCP structured output + if self.entity.output_schema and result.structuredContent: + for k, v in result.structuredContent.items(): + yield self.create_variable_message(k, v) def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]: """Process text content and yield appropriate messages.""" - try: - content_json = json.loads(content.text) - yield from self._process_json_content(content_json) - except json.JSONDecodeError: - yield self.create_text_message(content.text) + # Check if content looks like JSON before attempting to parse + text = content.text.strip() + if text and text[0] in ("{", "[") and text[-1] in ("}", "]"): + try: + content_json = json.loads(text) + yield from self._process_json_content(content_json) + return + except json.JSONDecodeError: + pass + + # If not JSON or parsing failed, treat as plain text + yield self.create_text_message(content.text) def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: """Process JSON content based on its type.""" @@ -126,3 +119,44 @@ class MCPTool(Tool): for key, value in parameter.items() if value is not None and not (isinstance(value, str) and value.strip() == "") } + + def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult: + headers = self.headers.copy() if self.headers else {} + tool_parameters = self._handle_none_parameter(tool_parameters) + + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from services.tools.mcp_tools_manage_service import MCPToolManageService + + # Step 1: Load provider entity and credentials in a short-lived session + # This minimizes database connection hold time + with Session(db.engine, expire_on_commit=False) as session: + mcp_service = MCPToolManageService(session=session) + provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) + + # Decrypt and prepare all credentials before closing session + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_headers() + + # Try to get existing token and add to headers + if not headers: + tokens = provider_entity.retrieve_tokens() + if tokens and tokens.access_token: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + + # Step 2: Session is now closed, perform network operations without holding database connection + # MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed + try: + with MCPClientWithAuthRetry( + server_url=server_url, + headers=headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + provider_entity=provider_entity, + ) as mcp_client: + return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPConnectionError as e: + raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e + except Exception as e: + raise ToolInvokeError(f"Failed to invoke tool: {e}") from e diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5c414915f4..ff7dcc0e55 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -14,17 +14,32 @@ from sqlalchemy.orm import Session from yarl import URL import contexts +from core.helper.provider_cache import ToolProviderCredentialsCache +from core.plugin.impl.tool import PluginToolManager +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.mcp_tool.tool import MCPTool +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool +from core.tools.utils.uuid_utils import is_valid_uuid +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from core.workflow.runtime.variable_pool import VariablePool +from extensions.ext_database import db +from models.provider_ids import ToolProviderID +from services.enterprise.plugin_manager_service import PluginCredentialType +from services.tools.mcp_tools_manage_service import MCPToolManageService + +if TYPE_CHECKING: + from core.workflow.nodes.tool.entities import ToolEntity + from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.helper.provider_cache import ToolProviderCredentialsCache from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool import Tool -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.tool import BuiltinTool @@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError -from core.tools.mcp_tool.provider import MCPToolProviderController -from core.tools.mcp_tool.tool import MCPTool -from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.plugin_tool.tool import PluginTool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter -from core.tools.utils.uuid_utils import is_valid_uuid -from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from extensions.ext_database import db -from models.provider_ids import ToolProviderID -from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider -from services.enterprise.plugin_manager_service import PluginCredentialType -from services.tools.mcp_tools_manage_service import MCPToolManageService +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: @@ -719,7 +724,9 @@ class ToolManager: ) result_providers[f"workflow_provider.{user_provider.name}"] = user_provider if "mcp" in filters: - mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True) + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True) for mcp_provider in mcp_providers: result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider @@ -774,17 +781,12 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: MCPToolProvider | None = ( - db.session.query(MCPToolProvider) - .where( - MCPToolProvider.server_identifier == provider_id, - MCPToolProvider.tenant_id == tenant_id, - ) - .first() - ) - - if provider is None: - raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + try: + provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id) + except ValueError: + raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") controller = MCPToolProviderController.from_db(provider) @@ -922,16 +924,15 @@ class ToolManager: @classmethod def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str: try: - mcp_provider: MCPToolProvider | None = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) - .first() - ) - - if mcp_provider is None: - raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") - - return mcp_provider.provider_icon + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + try: + mcp_provider = mcp_service.get_provider_entity( + provider_id=provider_id, tenant_id=tenant_id, by_server_id=True + ) + return mcp_provider.provider_icon + except ValueError: + raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/models/tools.py b/api/models/tools.py index 4e2976ce1b..12acc149b1 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,16 +1,13 @@ import json -from collections.abc import Mapping from datetime import datetime from decimal import Decimal from typing import TYPE_CHECKING, Any, cast -from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column -from core.helper import encrypter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration @@ -21,7 +18,7 @@ from .model import Account, App, Tenant from .types import StringUUID if TYPE_CHECKING: - from core.mcp.types import Tool as MCPTool + from core.entities.mcp_provider import MCPProviderEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration @@ -331,126 +328,36 @@ class MCPToolProvider(TypeBase): def load_user(self) -> Account | None: return db.session.query(Account).where(Account.id == self.user_id).first() - @property - def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - @property def credentials(self) -> dict[str, Any]: if not self.encrypted_credentials: return {} try: - return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} - except json.JSONDecodeError: - return {} - - @property - def mcp_tools(self) -> list["MCPTool"]: - from core.mcp.types import Tool as MCPTool - - return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)] - - @property - def provider_icon(self) -> Mapping[str, str] | str: - from core.file import helpers as file_helpers - - assert self.icon - try: - return json.loads(self.icon) - except json.JSONDecodeError: - return file_helpers.get_signed_file_url(self.icon) - - @property - def decrypted_server_url(self) -> str: - return encrypter.decrypt_token(self.tenant_id, self.server_url) - - @property - def decrypted_headers(self) -> dict[str, Any]: - """Get decrypted headers for MCP server requests.""" - from core.entities.provider_entities import BasicProviderConfig - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.utils.encryption import create_provider_encrypter - - try: - if not self.encrypted_headers: - return {} - - headers_data = json.loads(self.encrypted_headers) - - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - result = encrypter_instance.decrypt(headers_data) - return result + return json.loads(self.encrypted_credentials) except Exception: return {} @property - def masked_headers(self) -> dict[str, Any]: - """Get masked headers for frontend display.""" - from core.entities.provider_entities import BasicProviderConfig - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.utils.encryption import create_provider_encrypter - + def headers(self) -> dict[str, Any]: + if self.encrypted_headers is None: + return {} try: - if not self.encrypted_headers: - return {} - - headers_data = json.loads(self.encrypted_headers) - - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - # First decrypt, then mask - decrypted_headers = encrypter_instance.decrypt(headers_data) - result = encrypter_instance.mask_tool_credentials(decrypted_headers) - return result + return json.loads(self.encrypted_headers) except Exception: return {} @property - def masked_server_url(self) -> str: - def mask_url(url: str, mask_char: str = "*") -> str: - """ - mask the url to a simple string - """ - parsed = urlparse(url) - base_url = f"{parsed.scheme}://{parsed.netloc}" + def tool_dict(self) -> list[dict[str, Any]]: + try: + return json.loads(self.tools) if self.tools else [] + except (json.JSONDecodeError, TypeError): + return [] - if parsed.path and parsed.path != "/": - return f"{base_url}/{mask_char * 6}" - else: - return base_url + def to_entity(self) -> "MCPProviderEntity": + """Convert to domain entity""" + from core.entities.mcp_provider import MCPProviderEntity - return mask_url(self.decrypted_server_url) - - @property - def decrypted_credentials(self) -> dict[str, Any]: - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.encryption import create_provider_encrypter - - provider_controller = MCPToolProviderController.from_db(self) - - encrypter, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - cache=NoOpProviderCredentialCache(), - ) - - return encrypter.decrypt(self.credentials) + return MCPProviderEntity.from_db_model(self) class ToolModelInvoke(TypeBase): diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 92c33c1a49..e219bd4ce9 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -1,86 +1,118 @@ import hashlib import json +import logging from datetime import datetime +from enum import StrEnum from typing import Any +from urllib.parse import urlparse -from sqlalchemy import or_ +from pydantic import BaseModel, Field +from sqlalchemy import or_, select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.auth.auth_flow import auth +from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError -from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType -from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.utils.encryption import ProviderConfigEncrypter -from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService +logger = logging.getLogger(__name__) + +# Constants UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" +CLIENT_NAME = "Dify" +EMPTY_TOOLS_JSON = "[]" +EMPTY_CREDENTIALS_JSON = "{}" + + +class OAuthDataType(StrEnum): + """Types of OAuth data that can be saved.""" + + TOKENS = "tokens" + CLIENT_INFO = "client_info" + CODE_VERIFIER = "code_verifier" + MIXED = "mixed" + + +class ReconnectResult(BaseModel): + """Result of reconnecting to an MCP provider""" + + authed: bool = Field(description="Whether the provider is authenticated") + tools: str = Field(description="JSON string of tool list") + encrypted_credentials: str = Field(description="JSON string of encrypted credentials") + + +class ServerUrlValidationResult(BaseModel): + """Result of server URL validation check""" + + needs_validation: bool + validation_passed: bool = False + reconnect_result: ReconnectResult | None = None + encrypted_server_url: str | None = None + server_url_hash: str | None = None + + @property + def should_update_server_url(self) -> bool: + """Check if server URL should be updated based on validation result""" + return self.needs_validation and self.validation_passed and self.reconnect_result is not None class MCPToolManageService: - """ - Service class for managing mcp tools. - """ + """Service class for managing MCP tools and providers.""" - @staticmethod - def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]: + def __init__(self, session: Session): + self._session = session + + # ========== Provider CRUD Operations ========== + + def get_provider( + self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str + ) -> MCPToolProvider: """ - Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT. + Get MCP provider by ID or server identifier. Args: - headers: Dictionary of headers to encrypt - tenant_id: Tenant ID for encryption + provider_id: Provider ID (UUID) + server_identifier: Server identifier + tenant_id: Tenant ID Returns: - Dictionary with all headers encrypted + MCPToolProvider instance + + Raises: + ValueError: If provider not found """ - if not headers: - return {} + if server_identifier: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier + ) + else: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id + ) - from core.entities.provider_entities import BasicProviderConfig - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.utils.encryption import create_provider_encrypter - - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - return encrypter_instance.encrypt(headers) - - @staticmethod - def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: - res = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) - .first() - ) - if not res: + provider = self._session.scalar(stmt) + if not provider: raise ValueError("MCP tool not found") - return res + return provider - @staticmethod - def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider: - res = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) - .first() - ) - if not res: - raise ValueError("MCP tool not found") - return res + def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity: + """Get provider entity by ID or server identifier.""" + if by_server_id: + db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id) + else: + db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + return db_provider.to_entity() - @staticmethod - def create_mcp_provider( + def create_provider( + self, + *, tenant_id: str, name: str, server_url: str, @@ -89,37 +121,30 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, - timeout: float, - sse_read_timeout: float, + configuration: MCPConfiguration, + authentication: MCPAuthentication | None = None, headers: dict[str, str] | None = None, ) -> ToolProviderApiEntity: - server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - existing_provider = ( - db.session.query(MCPToolProvider) - .where( - MCPToolProvider.tenant_id == tenant_id, - or_( - MCPToolProvider.name == name, - MCPToolProvider.server_url_hash == server_url_hash, - MCPToolProvider.server_identifier == server_identifier, - ), - ) - .first() - ) - if existing_provider: - if existing_provider.name == name: - raise ValueError(f"MCP tool {name} already exists") - if existing_provider.server_url_hash == server_url_hash: - raise ValueError(f"MCP tool {server_url} already exists") - if existing_provider.server_identifier == server_identifier: - raise ValueError(f"MCP tool {server_identifier} already exists") - encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - # Encrypt headers - encrypted_headers = None - if headers: - encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id) - encrypted_headers = json.dumps(encrypted_headers_dict) + """Create a new MCP provider.""" + # Validate URL format + if not self._is_valid_url(server_url): + raise ValueError("Server URL is not valid.") + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() + + # Check for existing provider + self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier) + + # Encrypt sensitive data + encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) + encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None + encrypted_credentials = None + if authentication is not None and authentication.client_id: + encrypted_credentials = self._build_and_encrypt_credentials( + authentication.client_id, authentication.client_secret, tenant_id + ) + + # Create provider mcp_tool = MCPToolProvider( tenant_id=tenant_id, name=name, @@ -127,93 +152,23 @@ class MCPToolManageService: server_url_hash=server_url_hash, user_id=user_id, authed=False, - tools="[]", - icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, + tools=EMPTY_TOOLS_JSON, + icon=self._prepare_icon(icon, icon_type, icon_background), server_identifier=server_identifier, - timeout=timeout, - sse_read_timeout=sse_read_timeout, + timeout=configuration.timeout, + sse_read_timeout=configuration.sse_read_timeout, encrypted_headers=encrypted_headers, - ) - db.session.add(mcp_tool) - db.session.commit() - return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) - - @staticmethod - def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: - mcp_providers = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id) - .order_by(MCPToolProvider.name) - .all() - ) - return [ - ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list) - for mcp_provider in mcp_providers - ] - - @classmethod - def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - server_url = mcp_provider.decrypted_server_url - authed = mcp_provider.authed - headers = mcp_provider.decrypted_headers - timeout = mcp_provider.timeout - sse_read_timeout = mcp_provider.sse_read_timeout - - try: - with MCPClient( - server_url, - provider_id, - tenant_id, - authed=authed, - for_list=True, - headers=headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - ) as mcp_client: - tools = mcp_client.list_tools() - except MCPAuthError: - raise ValueError("Please auth the tool first") - except MCPError as e: - raise ValueError(f"Failed to connect to MCP server: {e}") - - try: - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) - mcp_provider.authed = True - mcp_provider.updated_at = datetime.now() - db.session.commit() - except Exception: - db.session.rollback() - raise - - user = mcp_provider.load_user() - if not mcp_provider.icon: - raise ValueError("MCP provider icon is required") - return ToolProviderApiEntity( - id=mcp_provider.id, - name=mcp_provider.name, - tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools), - type=ToolProviderType.MCP, - icon=mcp_provider.icon, - author=user.name if user else "Anonymous", - server_url=mcp_provider.masked_server_url, - updated_at=int(mcp_provider.updated_at.timestamp()), - description=I18nObject(en_US="", zh_Hans=""), - label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), - plugin_unique_identifier=mcp_provider.server_identifier, + encrypted_credentials=encrypted_credentials, ) - @classmethod - def delete_mcp_tool(cls, tenant_id: str, provider_id: str): - mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + self._session.add(mcp_tool) + self._session.flush() + mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) + return mcp_providers - db.session.delete(mcp_tool) - db.session.commit() - - @classmethod - def update_mcp_provider( - cls, + def update_provider( + self, + *, tenant_id: str, provider_id: str, name: str, @@ -222,129 +177,546 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, - timeout: float | None = None, - sse_read_timeout: float | None = None, headers: dict[str, str] | None = None, - ): - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + configuration: MCPConfiguration, + authentication: MCPAuthentication | None = None, + validation_result: ServerUrlValidationResult | None = None, + ) -> None: + """ + Update an MCP provider. - reconnect_result = None + Args: + validation_result: Pre-validation result from validate_server_url_change. + If provided and contains reconnect_result, it will be used + instead of performing network operations. + """ + mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + # Check for duplicate name (excluding current provider) + if name != mcp_provider.name: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, + MCPToolProvider.name == name, + MCPToolProvider.id != provider_id, + ) + existing_provider = self._session.scalar(stmt) + if existing_provider: + raise ValueError(f"MCP tool {name} already exists") + + # Get URL update data from validation result encrypted_server_url = None server_url_hash = None + reconnect_result = None - if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: - encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - - if server_url_hash != mcp_provider.server_url_hash: - reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id) + if validation_result and validation_result.encrypted_server_url: + # Use all data from validation result + encrypted_server_url = validation_result.encrypted_server_url + server_url_hash = validation_result.server_url_hash + reconnect_result = validation_result.reconnect_result try: + # Update basic fields mcp_provider.updated_at = datetime.now() mcp_provider.name = name - mcp_provider.icon = ( - json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon - ) + mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background) mcp_provider.server_identifier = server_identifier - if encrypted_server_url is not None and server_url_hash is not None: + # Update server URL if changed + if encrypted_server_url and server_url_hash: mcp_provider.server_url = encrypted_server_url mcp_provider.server_url_hash = server_url_hash if reconnect_result: - mcp_provider.authed = reconnect_result["authed"] - mcp_provider.tools = reconnect_result["tools"] - mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + mcp_provider.authed = reconnect_result.authed + mcp_provider.tools = reconnect_result.tools + mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials - if timeout is not None: - mcp_provider.timeout = timeout - if sse_read_timeout is not None: - mcp_provider.sse_read_timeout = sse_read_timeout + # Update optional configuration fields + self._update_optional_fields(mcp_provider, configuration) + + # Update headers if provided if headers is not None: - # Merge masked headers from frontend with existing real values - if headers: - # existing decrypted and masked headers - existing_decrypted = mcp_provider.decrypted_headers - existing_masked = mcp_provider.masked_headers + mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id) - # Build final headers: if value equals masked existing, keep original decrypted value - final_headers: dict[str, str] = {} - for key, incoming_value in headers.items(): - if ( - key in existing_masked - and key in existing_decrypted - and isinstance(incoming_value, str) - and incoming_value == existing_masked.get(key) - ): - # unchanged, use original decrypted value - final_headers[key] = str(existing_decrypted[key]) - else: - final_headers[key] = incoming_value + # Update credentials if provided + if authentication and authentication.client_id: + mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id) - encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id) - mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict) - else: - # Explicitly clear headers if empty dict passed - mcp_provider.encrypted_headers = None - db.session.commit() + # Flush changes to database + self._session.flush() except IntegrityError as e: - db.session.rollback() - error_msg = str(e.orig) - if "unique_mcp_provider_name" in error_msg: - raise ValueError(f"MCP tool {name} already exists") - if "unique_mcp_provider_server_url" in error_msg: - raise ValueError(f"MCP tool {server_url} already exists") - if "unique_mcp_provider_server_identifier" in error_msg: - raise ValueError(f"MCP tool {server_identifier} already exists") - raise - except Exception: - db.session.rollback() - raise + self._handle_integrity_error(e, name, server_url, server_identifier) - @classmethod - def update_mcp_provider_credentials( - cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False - ): - provider_controller = MCPToolProviderController.from_db(mcp_provider) + def delete_provider(self, *, tenant_id: str, provider_id: str) -> None: + """Delete an MCP provider.""" + mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + self._session.delete(mcp_tool) + + def list_providers( + self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True + ) -> list[ToolProviderApiEntity]: + """List all MCP providers for a tenant. + + Args: + tenant_id: Tenant ID + for_list: If True, return provider ID; if False, return server identifier + include_sensitive: If False, skip expensive decryption operations (default: True for backward compatibility) + """ + from models.account import Account + + stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name) + mcp_providers = self._session.scalars(stmt).all() + + if not mcp_providers: + return [] + + # Batch query all users to avoid N+1 problem + user_ids = {provider.user_id for provider in mcp_providers} + users = self._session.query(Account).where(Account.id.in_(user_ids)).all() + user_name_map = {user.id: user.name for user in users} + + return [ + ToolTransformService.mcp_provider_to_user_provider( + provider, + for_list=for_list, + user_name=user_name_map.get(provider.user_id), + include_sensitive=include_sensitive, + ) + for provider in mcp_providers + ] + + # ========== Tool Operations ========== + + def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: + """List tools from remote MCP server.""" + # Load provider and convert to entity + db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + provider_entity = db_provider.to_entity() + + # Verify authentication + if not provider_entity.authed: + raise ValueError("Please auth the tool first") + + # Prepare headers with auth token + headers = self._prepare_auth_headers(provider_entity) + + # Retrieve tools from remote server + server_url = provider_entity.decrypt_server_url() + try: + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) + except MCPError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") + + # Update database with retrieved tools + db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + db_provider.authed = True + db_provider.updated_at = datetime.now() + self._session.flush() + + # Build API response + return self._build_tool_provider_response(db_provider, provider_entity, tools) + + # ========== OAuth and Credentials Operations ========== + + def update_provider_credentials( + self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None + ) -> None: + """ + Update provider credentials with encryption. + + Args: + provider_id: Provider ID + tenant_id: Tenant ID + credentials: Credentials to save + authed: Whether provider is authenticated (None means keep current state) + """ + from core.tools.mcp_tool.provider import MCPToolProviderController + + # Get provider from current session + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + # Encrypt new credentials + provider_controller = MCPToolProviderController.from_db(provider) tool_configuration = ProviderConfigEncrypter( - tenant_id=mcp_provider.tenant_id, + tenant_id=provider.tenant_id, config=list(provider_controller.get_credentials_schema()), provider_config_cache=NoOpProviderCredentialCache(), ) - credentials = tool_configuration.encrypt(credentials) - mcp_provider.updated_at = datetime.now() - mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials}) - mcp_provider.authed = authed - if not authed: - mcp_provider.tools = "[]" - db.session.commit() + encrypted_credentials = tool_configuration.encrypt(credentials) - @classmethod - def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str): - # Get the existing provider to access headers and timeout settings - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - headers = mcp_provider.decrypted_headers - timeout = mcp_provider.timeout - sse_read_timeout = mcp_provider.sse_read_timeout + # Update provider + provider.updated_at = datetime.now() + provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials}) + + if authed is not None: + provider.authed = authed + if not authed: + provider.tools = EMPTY_TOOLS_JSON + + # Flush changes to database + self._session.flush() + + def save_oauth_data( + self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED + ) -> None: + """ + Save OAuth-related data (tokens, client info, code verifier). + + Args: + provider_id: Provider ID + tenant_id: Tenant ID + data: Data to save (tokens, client info, or code verifier) + data_type: Type of OAuth data to save + """ + # Determine if this makes the provider authenticated + authed = ( + data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None + ) + + # update_provider_credentials will validate provider existence + self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed) + + def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None: + """ + Clear all credentials for a provider. + + Args: + provider_id: Provider ID + tenant_id: Tenant ID + """ + # Get provider from current session + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + provider.tools = EMPTY_TOOLS_JSON + provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON + provider.updated_at = datetime.now() + provider.authed = False + + # ========== Private Helper Methods ========== + + def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None: + """Check if provider with same attributes already exists.""" + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, + or_( + MCPToolProvider.name == name, + MCPToolProvider.server_url_hash == server_url_hash, + MCPToolProvider.server_identifier == server_identifier, + ), + ) + existing_provider = self._session.scalar(stmt) + + if existing_provider: + if existing_provider.name == name: + raise ValueError(f"MCP tool {name} already exists") + if existing_provider.server_url_hash == server_url_hash: + raise ValueError("MCP tool with this server URL already exists") + if existing_provider.server_identifier == server_identifier: + raise ValueError(f"MCP tool {server_identifier} already exists") + + def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str: + """Prepare icon data for storage.""" + if icon_type == "emoji": + return json.dumps({"content": icon, "background": icon_background}) + return icon + + def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]: + """Encrypt specified fields in a dictionary. + + Args: + data: Dictionary containing data to encrypt + secret_fields: List of field names to encrypt + tenant_id: Tenant ID for encryption + + Returns: + JSON string of encrypted data + """ + from core.entities.provider_entities import BasicProviderConfig + from core.tools.utils.encryption import create_provider_encrypter + + # Create config for secret fields + config = [ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields + ] + + encrypter_instance, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=config, + cache=NoOpProviderCredentialCache(), + ) + + encrypted_data = encrypter_instance.encrypt(data) + return encrypted_data + + def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str: + """Encrypt headers and prepare for storage.""" + # All headers are treated as secret + return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id)) + + def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]: + """Prepare headers with OAuth token if available.""" + headers = provider_entity.decrypt_headers() + tokens = provider_entity.retrieve_tokens() + if tokens: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + return headers + + def _retrieve_remote_mcp_tools( + self, + server_url: str, + headers: dict[str, str], + provider_entity: MCPProviderEntity, + ): + """Retrieve tools from remote MCP server.""" + with MCPClientWithAuthRetry( + server_url=server_url, + headers=headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + provider_entity=provider_entity, + ) as mcp_client: + return mcp_client.list_tools() + + def execute_auth_actions(self, auth_result: Any) -> dict[str, str]: + """ + Execute the actions returned by the auth function. + + This method processes the AuthResult and performs the necessary database operations. + + Args: + auth_result: The result from the auth function + + Returns: + The response from the auth result + """ + from core.mcp.entities import AuthAction, AuthActionType + + action: AuthAction + for action in auth_result.actions: + if action.provider_id is None or action.tenant_id is None: + continue + + if action.action_type == AuthActionType.SAVE_CLIENT_INFO: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO) + elif action.action_type == AuthActionType.SAVE_TOKENS: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS) + elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER) + + return auth_result.response + + def auth_with_actions( + self, provider_entity: MCPProviderEntity, authorization_code: str | None = None + ) -> dict[str, str]: + """ + Perform authentication and execute all resulting actions. + + This method is used by MCPClientWithAuthRetry for automatic re-authentication. + + Args: + provider_entity: The MCP provider entity + authorization_code: Optional authorization code + + Returns: + Response dictionary from auth result + """ + auth_result = auth(provider_entity, authorization_code) + return self.execute_auth_actions(auth_result) + + def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: + """Attempt to reconnect to MCP provider with new server URL.""" + provider_entity = provider.to_entity() + headers = provider_entity.headers try: - with MCPClient( - server_url, - provider_id, - tenant_id, - authed=False, - for_list=True, - headers=headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - ) as mcp_client: - tools = mcp_client.list_tools() - return { - "authed": True, - "tools": json.dumps([tool.model_dump() for tool in tools]), - "encrypted_credentials": "{}", - } + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, + ) except MCPAuthError: - return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"} + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) except MCPError as e: raise ValueError(f"Failed to re-connect MCP server: {e}") from e + + def validate_server_url_change( + self, *, tenant_id: str, provider_id: str, new_server_url: str + ) -> ServerUrlValidationResult: + """ + Validate server URL change by attempting to connect to the new server. + This method should be called BEFORE update_provider to perform network operations + outside of the database transaction. + + Returns: + ServerUrlValidationResult: Validation result with connection status and tools if successful + """ + # Handle hidden/unchanged URL + if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url: + return ServerUrlValidationResult(needs_validation=False) + + # Validate URL format + if not self._is_valid_url(new_server_url): + raise ValueError("Server URL is not valid.") + + # Always encrypt and hash the URL + encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) + new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() + + # Get current provider + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + # Check if URL is actually different + if new_server_url_hash == provider.server_url_hash: + # URL hasn't changed, but still return the encrypted data + return ServerUrlValidationResult( + needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + ) + + # Perform validation by attempting to connect + reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + return ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=reconnect_result, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, + ) + + def _build_tool_provider_response( + self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list + ) -> ToolProviderApiEntity: + """Build API response for tool provider.""" + user = db_provider.load_user() + response = provider_entity.to_api_response( + user_name=user.name if user else None, + ) + response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools) + response["plugin_unique_identifier"] = provider_entity.provider_id + return ToolProviderApiEntity(**response) + + def _handle_integrity_error( + self, error: IntegrityError, name: str, server_url: str, server_identifier: str + ) -> None: + """Handle database integrity errors with user-friendly messages.""" + error_msg = str(error.orig) + if "unique_mcp_provider_name" in error_msg: + raise ValueError(f"MCP tool {name} already exists") + if "unique_mcp_provider_server_url" in error_msg: + raise ValueError(f"MCP tool {server_url} already exists") + if "unique_mcp_provider_server_identifier" in error_msg: + raise ValueError(f"MCP tool {server_identifier} already exists") + raise + + def _is_valid_url(self, url: str) -> bool: + """Validate URL format.""" + if not url: + return False + try: + parsed = urlparse(url) + return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] + except (ValueError, TypeError): + return False + + def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None: + """Update optional configuration fields using setattr for cleaner code.""" + field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout} + + for field, value in field_mapping.items(): + if value is not None: + setattr(mcp_provider, field, value) + + def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None: + """Process headers update, handling empty dict to clear headers.""" + if not headers: + return None + + # Merge with existing headers to preserve masked values + final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) + return self._prepare_encrypted_dict(final_headers, tenant_id) + + def _process_credentials( + self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str + ) -> str: + """Process credentials update, handling masked values.""" + # Merge with existing credentials + final_client_id, final_client_secret = self._merge_credentials_with_masked( + authentication.client_id, authentication.client_secret, mcp_provider + ) + + # Build and encrypt + return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id) + + def _merge_headers_with_masked( + self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider + ) -> dict[str, str]: + """Merge incoming headers with existing ones, preserving unchanged masked values. + + Args: + incoming_headers: Headers from frontend (may contain masked values) + mcp_provider: The MCP provider instance + + Returns: + Final headers dict with proper values (original for unchanged masked, new for changed) + """ + mcp_provider_entity = mcp_provider.to_entity() + existing_decrypted = mcp_provider_entity.decrypt_headers() + existing_masked = mcp_provider_entity.masked_headers() + + return { + key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value) + for key, value in incoming_headers.items() + if key in existing_decrypted or value != existing_masked.get(key) + } + + def _merge_credentials_with_masked( + self, + client_id: str, + client_secret: str | None, + mcp_provider: MCPToolProvider, + ) -> tuple[ + str, + str | None, + ]: + """Merge incoming credentials with existing ones, preserving unchanged masked values. + + Args: + client_id: Client ID from frontend (may be masked) + client_secret: Client secret from frontend (may be masked) + mcp_provider: The MCP provider instance + + Returns: + Tuple of (final_client_id, final_client_secret) + """ + mcp_provider_entity = mcp_provider.to_entity() + existing_decrypted = mcp_provider_entity.decrypt_credentials() + existing_masked = mcp_provider_entity.masked_credentials() + + # Check if client_id is masked and unchanged + final_client_id = client_id + if existing_masked.get("client_id") and client_id == existing_masked["client_id"]: + # Use existing decrypted value + final_client_id = existing_decrypted.get("client_id", client_id) + + # Check if client_secret is masked and unchanged + final_client_secret = client_secret + if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]: + # Use existing decrypted value + final_client_secret = existing_decrypted.get("client_secret", client_secret) + + return final_client_id, final_client_secret + + def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str: + """Build credentials and encrypt sensitive fields.""" + # Create a flat structure with all credential data + credentials_data = { + "client_id": client_id, + "client_name": CLIENT_NAME, + "is_dynamic_registration": False, + } + secret_fields = [] + if client_secret is not None: + credentials_data["encrypted_client_secret"] = client_secret + secret_fields = ["encrypted_client_secret"] + client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id) + return json.dumps({"client_information": client_info}) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b7850ea150..6e95513318 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -3,9 +3,11 @@ import logging from collections.abc import Mapping from typing import Any, Union +from pydantic import ValidationError from yarl import URL from configs import dify_config +from core.entities.mcp_provider import MCPConfiguration from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity @@ -232,40 +234,57 @@ class ToolTransformService: ) @staticmethod - def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity: - user = db_provider.load_user() - return ToolProviderApiEntity( - id=db_provider.server_identifier if not for_list else db_provider.id, - author=user.name if user else "Anonymous", - name=db_provider.name, - icon=db_provider.provider_icon, - type=ToolProviderType.MCP, - is_team_authorization=db_provider.authed, - server_url=db_provider.masked_server_url, - tools=ToolTransformService.mcp_tool_to_user_tool( - db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)] - ), - updated_at=int(db_provider.updated_at.timestamp()), - label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), - description=I18nObject(en_US="", zh_Hans=""), - server_identifier=db_provider.server_identifier, - timeout=db_provider.timeout, - sse_read_timeout=db_provider.sse_read_timeout, - masked_headers=db_provider.masked_headers, - original_headers=db_provider.decrypted_headers, - ) + def mcp_provider_to_user_provider( + db_provider: MCPToolProvider, + for_list: bool = False, + user_name: str | None = None, + include_sensitive: bool = True, + ) -> ToolProviderApiEntity: + # Use provided user_name to avoid N+1 query, fallback to load_user() if not provided + if user_name is None: + user = db_provider.load_user() + user_name = user.name if user else None + + # Convert to entity and use its API response method + provider_entity = db_provider.to_entity() + + response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive) + try: + mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)] + except (ValidationError, json.JSONDecodeError): + mcp_tools = [] + # Add additional fields specific to the transform + response["id"] = db_provider.server_identifier if not for_list else db_provider.id + response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools, user_name=user_name) + response["server_identifier"] = db_provider.server_identifier + + # Convert configuration dict to MCPConfiguration object + if "configuration" in response and isinstance(response["configuration"], dict): + response["configuration"] = MCPConfiguration( + timeout=float(response["configuration"]["timeout"]), + sse_read_timeout=float(response["configuration"]["sse_read_timeout"]), + ) + + return ToolProviderApiEntity(**response) @staticmethod - def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]: - user = mcp_provider.load_user() + def mcp_tool_to_user_tool( + mcp_provider: MCPToolProvider, tools: list[MCPTool], user_name: str | None = None + ) -> list[ToolApiEntity]: + # Use provided user_name to avoid N+1 query, fallback to load_user() if not provided + if user_name is None: + user = mcp_provider.load_user() + user_name = user.name if user else "Anonymous" + return [ ToolApiEntity( - author=user.name if user else "Anonymous", + author=user_name or "Anonymous", name=tool.name, label=I18nObject(en_US=tool.name, zh_Hans=tool.name), description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""), parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema), labels=[], + output_schema=tool.outputSchema or {}, ) for tool in tools ] @@ -412,7 +431,7 @@ class ToolTransformService: ) @staticmethod - def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]: + def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]: """ Convert MCP JSON schema to tool parameters @@ -421,7 +440,7 @@ class ToolTransformService: """ def create_parameter( - name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None + name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None ) -> ToolParameter: """Create a ToolParameter instance with given attributes""" input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {} @@ -436,7 +455,9 @@ class ToolTransformService: **input_schema_dict, ) - def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]: + def process_properties( + props: dict[str, dict[str, Any]], required: list[str], prefix: str = "" + ) -> list[ToolParameter]: """Process properties recursively""" TYPE_MAPPING = {"integer": "number", "float": "number"} COMPLEX_TYPES = ["array", "object"] diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 71d55c3ade..8c190762cf 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -20,12 +20,21 @@ class TestMCPToolManageService: patch("services.tools.mcp_tools_manage_service.ToolTransformService") as mock_tool_transform_service, ): # Setup default mock returns + from core.tools.entities.api_entities import ToolProviderApiEntity + from core.tools.entities.common_entities import I18nObject + mock_encrypter.encrypt_token.return_value = "encrypted_server_url" - mock_tool_transform_service.mcp_provider_to_user_provider.return_value = { - "id": "test_id", - "name": "test_name", - "type": ToolProviderType.MCP, - } + mock_tool_transform_service.mcp_provider_to_user_provider.return_value = ToolProviderApiEntity( + id="test_id", + author="test_author", + name="test_name", + type=ToolProviderType.MCP, + description=I18nObject(en_US="Test Description", zh_Hans="测试描述"), + icon={"type": "emoji", "content": "🤖"}, + label=I18nObject(en_US="Test Label", zh_Hans="测试标签"), + labels=[], + tools=[], + ) yield { "encrypter": mock_encrypter, @@ -104,9 +113,9 @@ class TestMCPToolManageService: mcp_provider = MCPToolProvider( tenant_id=tenant_id, name=fake.company(), - server_identifier=fake.uuid4(), + server_identifier=str(fake.uuid4()), server_url="encrypted_server_url", - server_url_hash=fake.sha256(), + server_url_hash=str(fake.sha256()), user_id=user_id, authed=False, tools="[]", @@ -144,7 +153,10 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - result = MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider.id, tenant.id) + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id) # Assert: Verify the expected outcomes assert result is not None @@ -154,8 +166,6 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) assert result.id is not None assert result.server_identifier == mcp_provider.server_identifier @@ -177,11 +187,14 @@ class TestMCPToolManageService: db_session_with_containers, mock_external_service_dependencies ) - non_existent_id = fake.uuid4() + non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool not found"): - MCPToolManageService.get_mcp_provider_by_provider_id(non_existent_id, tenant.id) + service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id) def test_get_mcp_provider_by_provider_id_tenant_isolation( self, db_session_with_containers, mock_external_service_dependencies @@ -210,8 +223,11 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool not found"): - MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider1.id, tenant2.id) + service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id) def test_get_mcp_provider_by_server_identifier_success( self, db_session_with_containers, mock_external_service_dependencies @@ -235,7 +251,10 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - result = MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider.server_identifier, tenant.id) + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id) # Assert: Verify the expected outcomes assert result is not None @@ -245,8 +264,6 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) assert result.id is not None assert result.name == mcp_provider.name @@ -268,11 +285,14 @@ class TestMCPToolManageService: db_session_with_containers, mock_external_service_dependencies ) - non_existent_identifier = fake.uuid4() + non_existent_identifier = str(fake.uuid4()) # Act & Assert: Verify proper error handling + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool not found"): - MCPToolManageService.get_mcp_provider_by_server_identifier(non_existent_identifier, tenant.id) + service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id) def test_get_mcp_provider_by_server_identifier_tenant_isolation( self, db_session_with_containers, mock_external_service_dependencies @@ -301,8 +321,11 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool not found"): - MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider1.server_identifier, tenant2.id) + service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id) def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -322,15 +345,30 @@ class TestMCPToolManageService: ) # Setup mocks for provider creation + from core.tools.entities.api_entities import ToolProviderApiEntity + from core.tools.entities.common_entities import I18nObject + mock_external_service_dependencies["encrypter"].encrypt_token.return_value = "encrypted_server_url" - mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.return_value = { - "id": "new_provider_id", - "name": "Test MCP Provider", - "type": ToolProviderType.MCP, - } + mock_external_service_dependencies[ + "tool_transform_service" + ].mcp_provider_to_user_provider.return_value = ToolProviderApiEntity( + id="new_provider_id", + author=account.name, + name="Test MCP Provider", + type=ToolProviderType.MCP, + description=I18nObject(en_US="Test MCP Provider Description", zh_Hans="测试MCP提供者描述"), + icon={"type": "emoji", "content": "🤖"}, + label=I18nObject(en_US="Test MCP Provider", zh_Hans="测试MCP提供者"), + labels=[], + tools=[], + ) # Act: Execute the method under test - result = MCPToolManageService.create_mcp_provider( + from core.entities.mcp_provider import MCPConfiguration + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", server_url="https://example.com/mcp", @@ -339,14 +377,16 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#FF6B6B", server_identifier="test_identifier_123", - timeout=30.0, - sse_read_timeout=300.0, + configuration=MCPConfiguration( + timeout=30.0, + sse_read_timeout=300.0, + ), ) # Assert: Verify the expected outcomes assert result is not None - assert result["name"] == "Test MCP Provider" - assert result["type"] == ToolProviderType.MCP + assert result.name == "Test MCP Provider" + assert result.type == ToolProviderType.MCP # Verify database state from extensions.ext_database import db @@ -386,7 +426,11 @@ class TestMCPToolManageService: ) # Create first provider - MCPToolManageService.create_mcp_provider( + from core.entities.mcp_provider import MCPConfiguration + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", server_url="https://example1.com/mcp", @@ -395,13 +439,15 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#FF6B6B", server_identifier="test_identifier_1", - timeout=30.0, - sse_read_timeout=300.0, + configuration=MCPConfiguration( + timeout=30.0, + sse_read_timeout=300.0, + ), ) # Act & Assert: Verify proper error handling for duplicate name with pytest.raises(ValueError, match="MCP tool Test MCP Provider already exists"): - MCPToolManageService.create_mcp_provider( + service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", # Duplicate name server_url="https://example2.com/mcp", @@ -410,8 +456,10 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#4ECDC4", server_identifier="test_identifier_2", - timeout=45.0, - sse_read_timeout=400.0, + configuration=MCPConfiguration( + timeout=45.0, + sse_read_timeout=400.0, + ), ) def test_create_mcp_provider_duplicate_server_url( @@ -432,7 +480,11 @@ class TestMCPToolManageService: ) # Create first provider - MCPToolManageService.create_mcp_provider( + from core.entities.mcp_provider import MCPConfiguration + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", server_url="https://example.com/mcp", @@ -441,13 +493,15 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#FF6B6B", server_identifier="test_identifier_1", - timeout=30.0, - sse_read_timeout=300.0, + configuration=MCPConfiguration( + timeout=30.0, + sse_read_timeout=300.0, + ), ) # Act & Assert: Verify proper error handling for duplicate server URL - with pytest.raises(ValueError, match="MCP tool https://example.com/mcp already exists"): - MCPToolManageService.create_mcp_provider( + with pytest.raises(ValueError, match="MCP tool with this server URL already exists"): + service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 2", server_url="https://example.com/mcp", # Duplicate URL @@ -456,8 +510,10 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#4ECDC4", server_identifier="test_identifier_2", - timeout=45.0, - sse_read_timeout=400.0, + configuration=MCPConfiguration( + timeout=45.0, + sse_read_timeout=400.0, + ), ) def test_create_mcp_provider_duplicate_server_identifier( @@ -478,7 +534,11 @@ class TestMCPToolManageService: ) # Create first provider - MCPToolManageService.create_mcp_provider( + from core.entities.mcp_provider import MCPConfiguration + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", server_url="https://example1.com/mcp", @@ -487,13 +547,15 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#FF6B6B", server_identifier="test_identifier_123", - timeout=30.0, - sse_read_timeout=300.0, + configuration=MCPConfiguration( + timeout=30.0, + sse_read_timeout=300.0, + ), ) # Act & Assert: Verify proper error handling for duplicate server identifier with pytest.raises(ValueError, match="MCP tool test_identifier_123 already exists"): - MCPToolManageService.create_mcp_provider( + service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 2", server_url="https://example2.com/mcp", @@ -502,8 +564,10 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#4ECDC4", server_identifier="test_identifier_123", # Duplicate identifier - timeout=45.0, - sse_read_timeout=400.0, + configuration=MCPConfiguration( + timeout=45.0, + sse_read_timeout=400.0, + ), ) def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -543,23 +607,59 @@ class TestMCPToolManageService: db.session.commit() # Setup mock for transformation service + from core.tools.entities.api_entities import ToolProviderApiEntity + from core.tools.entities.common_entities import I18nObject + mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [ - {"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP}, - {"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP}, - {"id": provider3.id, "name": provider3.name, "type": ToolProviderType.MCP}, + ToolProviderApiEntity( + id=provider1.id, + author=account.name, + name=provider1.name, + type=ToolProviderType.MCP, + description=I18nObject(en_US="Alpha Provider Description", zh_Hans="Alpha提供者描述"), + icon={"type": "emoji", "content": "🅰️"}, + label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name), + labels=[], + tools=[], + ), + ToolProviderApiEntity( + id=provider2.id, + author=account.name, + name=provider2.name, + type=ToolProviderType.MCP, + description=I18nObject(en_US="Beta Provider Description", zh_Hans="Beta提供者描述"), + icon={"type": "emoji", "content": "🅱️"}, + label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name), + labels=[], + tools=[], + ), + ToolProviderApiEntity( + id=provider3.id, + author=account.name, + name=provider3.name, + type=ToolProviderType.MCP, + description=I18nObject(en_US="Gamma Provider Description", zh_Hans="Gamma提供者描述"), + icon={"type": "emoji", "content": "Γ"}, + label=I18nObject(en_US=provider3.name, zh_Hans=provider3.name), + labels=[], + tools=[], + ), ] # Act: Execute the method under test - result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=True) + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service.list_providers(tenant_id=tenant.id, for_list=True) # Assert: Verify the expected outcomes assert result is not None assert len(result) == 3 # Verify correct ordering by name - assert result[0]["name"] == "Alpha Provider" - assert result[1]["name"] == "Beta Provider" - assert result[2]["name"] == "Gamma Provider" + assert result[0].name == "Alpha Provider" + assert result[1].name == "Beta Provider" + assert result[2].name == "Gamma Provider" # Verify mock interactions assert ( @@ -584,7 +684,10 @@ class TestMCPToolManageService: # No MCP providers created for this tenant # Act: Execute the method under test - result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=False) + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service.list_providers(tenant_id=tenant.id, for_list=False) # Assert: Verify the expected outcomes assert result is not None @@ -624,20 +727,46 @@ class TestMCPToolManageService: ) # Setup mock for transformation service + from core.tools.entities.api_entities import ToolProviderApiEntity + from core.tools.entities.common_entities import I18nObject + mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [ - {"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP}, - {"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP}, + ToolProviderApiEntity( + id=provider1.id, + author=account1.name, + name=provider1.name, + type=ToolProviderType.MCP, + description=I18nObject(en_US="Provider 1 Description", zh_Hans="提供者1描述"), + icon={"type": "emoji", "content": "1️⃣"}, + label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name), + labels=[], + tools=[], + ), + ToolProviderApiEntity( + id=provider2.id, + author=account2.name, + name=provider2.name, + type=ToolProviderType.MCP, + description=I18nObject(en_US="Provider 2 Description", zh_Hans="提供者2描述"), + icon={"type": "emoji", "content": "2️⃣"}, + label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name), + labels=[], + tools=[], + ), ] # Act: Execute the method under test for both tenants - result1 = MCPToolManageService.retrieve_mcp_tools(tenant1.id, for_list=True) - result2 = MCPToolManageService.retrieve_mcp_tools(tenant2.id, for_list=True) + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result1 = service.list_providers(tenant_id=tenant1.id, for_list=True) + result2 = service.list_providers(tenant_id=tenant2.id, for_list=True) # Assert: Verify tenant isolation assert len(result1) == 1 assert len(result2) == 1 - assert result1[0]["id"] == provider1.id - assert result2[0]["id"] == provider2.id + assert result1[0].id == provider1.id + assert result2[0].id == provider2.id def test_list_mcp_tool_from_remote_server_success( self, db_session_with_containers, mock_external_service_dependencies @@ -661,17 +790,20 @@ class TestMCPToolManageService: mcp_provider = self._create_test_mcp_provider( db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id ) - mcp_provider.server_url = "encrypted_server_url" - mcp_provider.authed = False + # Use a valid base64 encoded string to avoid decryption errors + import base64 + + mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode() + mcp_provider.authed = True # Provider must be authenticated to list tools mcp_provider.tools = "[]" from extensions.ext_database import db db.session.commit() - # Mock the decrypted_server_url property to avoid encryption issues - with patch("models.tools.encrypter") as mock_encrypter: - mock_encrypter.decrypt_token.return_value = "https://example.com/mcp" + # Mock the decryption process at the rsa level to avoid key file issues + with patch("libs.rsa.decrypt") as mock_decrypt: + mock_decrypt.return_value = "https://example.com/mcp" # Mock MCPClient and its context manager mock_tools = [ @@ -683,13 +815,16 @@ class TestMCPToolManageService: )(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - result = MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id) + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes assert result is not None @@ -705,16 +840,8 @@ class TestMCPToolManageService: assert mcp_provider.updated_at is not None # Verify mock interactions - mock_mcp_client.assert_called_once_with( - "https://example.com/mcp", - mcp_provider.id, - tenant.id, - authed=False, - for_list=True, - headers={}, - timeout=30.0, - sse_read_timeout=300.0, - ) + # MCPClientWithAuthRetry is called with different parameters + mock_mcp_client.assert_called_once() def test_list_mcp_tool_from_remote_server_auth_error( self, db_session_with_containers, mock_external_service_dependencies @@ -737,7 +864,10 @@ class TestMCPToolManageService: mcp_provider = self._create_test_mcp_provider( db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id ) - mcp_provider.server_url = "encrypted_server_url" + # Use a valid base64 encoded string to avoid decryption errors + import base64 + + mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode() mcp_provider.authed = False mcp_provider.tools = "[]" @@ -745,20 +875,23 @@ class TestMCPToolManageService: db.session.commit() - # Mock the decrypted_server_url property to avoid encryption issues - with patch("models.tools.encrypter") as mock_encrypter: - mock_encrypter.decrypt_token.return_value = "https://example.com/mcp" + # Mock the decryption process at the rsa level to avoid key file issues + with patch("libs.rsa.decrypt") as mock_decrypt: + mock_decrypt.return_value = "https://example.com/mcp" # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act & Assert: Verify proper error handling + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Please auth the tool first"): - MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id) + service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed db.session.refresh(mcp_provider) @@ -786,32 +919,38 @@ class TestMCPToolManageService: mcp_provider = self._create_test_mcp_provider( db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id ) - mcp_provider.server_url = "encrypted_server_url" - mcp_provider.authed = False + # Use a valid base64 encoded string to avoid decryption errors + import base64 + + mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode() + mcp_provider.authed = True # Provider must be authenticated to test connection errors mcp_provider.tools = "[]" from extensions.ext_database import db db.session.commit() - # Mock the decrypted_server_url property to avoid encryption issues - with patch("models.tools.encrypter") as mock_encrypter: - mock_encrypter.decrypt_token.return_value = "https://example.com/mcp" + # Mock the decryption process at the rsa level to avoid key file issues + with patch("libs.rsa.decrypt") as mock_decrypt: + mock_decrypt.return_value = "https://example.com/mcp" # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"): - MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id) + service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed db.session.refresh(mcp_provider) - assert mcp_provider.authed is False + assert mcp_provider.authed is True # Provider remains authenticated assert mcp_provider.tools == "[]" def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -840,7 +979,8 @@ class TestMCPToolManageService: assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None # Act: Execute the method under test - MCPToolManageService.delete_mcp_tool(tenant.id, mcp_provider.id) + service = MCPToolManageService(db.session()) + service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes # Provider should be deleted from database @@ -862,11 +1002,14 @@ class TestMCPToolManageService: db_session_with_containers, mock_external_service_dependencies ) - non_existent_id = fake.uuid4() + non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool not found"): - MCPToolManageService.delete_mcp_tool(tenant.id, non_existent_id) + service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id) def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -893,8 +1036,11 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool not found"): - MCPToolManageService.delete_mcp_tool(tenant2.id, mcp_provider1.id) + service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id) # Verify provider still exists in tenant1 from extensions.ext_database import db @@ -929,7 +1075,10 @@ class TestMCPToolManageService: db.session.commit() # Act: Execute the method under test - MCPToolManageService.update_mcp_provider( + from core.entities.mcp_provider import MCPConfiguration + + service = MCPToolManageService(db.session()) + service.update_provider( tenant_id=tenant.id, provider_id=mcp_provider.id, name="Updated MCP Provider", @@ -938,8 +1087,10 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#4ECDC4", server_identifier="updated_identifier_123", - timeout=45.0, - sse_read_timeout=400.0, + configuration=MCPConfiguration( + timeout=45.0, + sse_read_timeout=400.0, + ), ) # Assert: Verify the expected outcomes @@ -953,70 +1104,10 @@ class TestMCPToolManageService: # Verify icon was updated import json - icon_data = json.loads(mcp_provider.icon) + icon_data = json.loads(mcp_provider.icon or "{}") assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_with_server_url_change( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test successful update of MCP provider with server URL change. - - This test verifies: - - Proper handling of server URL changes - - Correct reconnection logic - - Database state updates - - External service integration - """ - # Arrange: Create test data - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - # Create MCP provider - mcp_provider = self._create_test_mcp_provider( - db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id - ) - - from extensions.ext_database import db - - db.session.commit() - - # Mock the reconnection method - with patch.object(MCPToolManageService, "_re_connect_mcp_provider") as mock_reconnect: - mock_reconnect.return_value = { - "authed": True, - "tools": '[{"name": "test_tool"}]', - "encrypted_credentials": "{}", - } - - # Act: Execute the method under test - MCPToolManageService.update_mcp_provider( - tenant_id=tenant.id, - provider_id=mcp_provider.id, - name="Updated MCP Provider", - server_url="https://new-example.com/mcp", - icon="🚀", - icon_type="emoji", - icon_background="#4ECDC4", - server_identifier="updated_identifier_123", - timeout=45.0, - sse_read_timeout=400.0, - ) - - # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) - assert mcp_provider.name == "Updated MCP Provider" - assert mcp_provider.server_identifier == "updated_identifier_123" - assert mcp_provider.timeout == 45.0 - assert mcp_provider.sse_read_timeout == 400.0 - assert mcp_provider.updated_at is not None - - # Verify reconnection was called - mock_reconnect.assert_called_once_with("https://new-example.com/mcp", mcp_provider.id, tenant.id) - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): """ Test error handling when updating MCP provider with duplicate name. @@ -1048,8 +1139,12 @@ class TestMCPToolManageService: db.session.commit() # Act & Assert: Verify proper error handling for duplicate name + from core.entities.mcp_provider import MCPConfiguration + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="MCP tool First Provider already exists"): - MCPToolManageService.update_mcp_provider( + service.update_provider( tenant_id=tenant.id, provider_id=provider2.id, name="First Provider", # Duplicate name @@ -1058,8 +1153,10 @@ class TestMCPToolManageService: icon_type="emoji", icon_background="#4ECDC4", server_identifier="unique_identifier", - timeout=45.0, - sse_read_timeout=400.0, + configuration=MCPConfiguration( + timeout=45.0, + sse_read_timeout=400.0, + ), ) def test_update_mcp_provider_credentials_success( @@ -1094,19 +1191,25 @@ class TestMCPToolManageService: # Mock the provider controller and encryption with ( - patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller, - patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter, + patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller, + patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter, ): # Setup mocks - mock_controller_instance = mock_controller._from_db.return_value + mock_controller_instance = mock_controller.from_db.return_value mock_controller_instance.get_credentials_schema.return_value = [] mock_encrypter_instance = mock_encrypter.return_value mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - MCPToolManageService.update_mcp_provider_credentials( - mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + service.update_provider_credentials( + provider_id=mcp_provider.id, + tenant_id=tenant.id, + credentials={"new_key": "new_value"}, + authed=True, ) # Assert: Verify the expected outcomes @@ -1117,7 +1220,7 @@ class TestMCPToolManageService: # Verify credentials were encrypted and merged import json - credentials = json.loads(mcp_provider.encrypted_credentials) + credentials = json.loads(mcp_provider.encrypted_credentials or "{}") assert "existing_key" in credentials assert "new_key" in credentials @@ -1152,19 +1255,25 @@ class TestMCPToolManageService: # Mock the provider controller and encryption with ( - patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller, - patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter, + patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller, + patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter, ): # Setup mocks - mock_controller_instance = mock_controller._from_db.return_value + mock_controller_instance = mock_controller.from_db.return_value mock_controller_instance.get_credentials_schema.return_value = [] mock_encrypter_instance = mock_encrypter.return_value mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - MCPToolManageService.update_mcp_provider_credentials( - mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + service.update_provider_credentials( + provider_id=mcp_provider.id, + tenant_id=tenant.id, + credentials={"new_key": "new_value"}, + authed=False, ) # Assert: Verify the expected outcomes @@ -1199,41 +1308,37 @@ class TestMCPToolManageService: type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - result = MCPToolManageService._re_connect_mcp_provider( - "https://example.com/mcp", mcp_provider.id, tenant.id + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service._reconnect_provider( + server_url="https://example.com/mcp", + provider=mcp_provider, ) # Assert: Verify the expected outcomes assert result is not None - assert result["authed"] is True - assert result["tools"] is not None - assert result["encrypted_credentials"] == "{}" + assert result.authed is True + assert result.tools is not None + assert result.encrypted_credentials == "{}" # Verify tools were properly serialized import json - tools_data = json.loads(result["tools"]) + tools_data = json.loads(result.tools) assert len(tools_data) == 2 assert tools_data[0]["name"] == "test_tool_1" assert tools_data[1]["name"] == "test_tool_2" # Verify mock interactions - mock_mcp_client.assert_called_once_with( - "https://example.com/mcp", - mcp_provider.id, - tenant.id, - authed=False, - for_list=True, - headers={}, - timeout=30.0, - sse_read_timeout=300.0, - ) + provider_entity = mcp_provider.to_entity() + mock_mcp_client.assert_called_once() def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -1256,22 +1361,26 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act: Execute the method under test - result = MCPToolManageService._re_connect_mcp_provider( - "https://example.com/mcp", mcp_provider.id, tenant.id + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) + result = service._reconnect_provider( + server_url="https://example.com/mcp", + provider=mcp_provider, ) # Assert: Verify the expected outcomes assert result is not None - assert result["authed"] is False - assert result["tools"] == "[]" - assert result["encrypted_credentials"] == "{}" + assert result.authed is False + assert result.tools == "[]" + assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( self, db_session_with_containers, mock_external_service_dependencies @@ -1295,12 +1404,18 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling + from extensions.ext_database import db + + service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): - MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id) + service._reconnect_provider( + server_url="https://example.com/mcp", + provider=mcp_provider, + ) diff --git a/api/tests/unit_tests/core/mcp/__init__.py b/api/tests/unit_tests/core/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/mcp/auth/__init__.py b/api/tests/unit_tests/core/mcp/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py new file mode 100644 index 0000000000..12a9f11205 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -0,0 +1,740 @@ +"""Unit tests for MCP OAuth authentication flow.""" + +from unittest.mock import Mock, patch + +import pytest + +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.auth.auth_flow import ( + OAUTH_STATE_EXPIRY_SECONDS, + OAUTH_STATE_REDIS_KEY_PREFIX, + OAuthCallbackState, + _create_secure_redis_state, + _retrieve_redis_state, + auth, + check_support_resource_discovery, + discover_oauth_metadata, + exchange_authorization, + generate_pkce_challenge, + handle_callback, + refresh_authorization, + register_client, + start_authorization, +) +from core.mcp.entities import AuthActionType, AuthResult +from core.mcp.types import ( + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthTokens, +) + + +class TestPKCEGeneration: + """Test PKCE challenge generation.""" + + def test_generate_pkce_challenge(self): + """Test PKCE challenge and verifier generation.""" + code_verifier, code_challenge = generate_pkce_challenge() + + # Verify format - should be URL-safe base64 without padding + assert "=" not in code_verifier + assert "+" not in code_verifier + assert "/" not in code_verifier + assert "=" not in code_challenge + assert "+" not in code_challenge + assert "/" not in code_challenge + + # Verify length + assert len(code_verifier) > 40 # Should be around 54 characters + assert len(code_challenge) > 40 # Should be around 43 characters + + def test_generate_pkce_challenge_uniqueness(self): + """Test that PKCE generation produces unique values.""" + results = set() + for _ in range(10): + code_verifier, code_challenge = generate_pkce_challenge() + results.add((code_verifier, code_challenge)) + + # All should be unique + assert len(results) == 10 + + +class TestRedisStateManagement: + """Test Redis state management functions.""" + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_create_secure_redis_state(self, mock_redis): + """Test creating secure Redis state.""" + state_data = OAuthCallbackState( + provider_id="test-provider", + tenant_id="test-tenant", + server_url="https://example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="test-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + + state_key = _create_secure_redis_state(state_data) + + # Verify state key format + assert len(state_key) > 20 # Should be a secure random token + + # Verify Redis call + mock_redis.setex.assert_called_once() + call_args = mock_redis.setex.call_args + assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX) + assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS + assert state_data.model_dump_json() in call_args[0][2] + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_retrieve_redis_state_success(self, mock_redis): + """Test retrieving state from Redis.""" + state_data = OAuthCallbackState( + provider_id="test-provider", + tenant_id="test-tenant", + server_url="https://example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="test-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + mock_redis.get.return_value = state_data.model_dump_json() + + result = _retrieve_redis_state("test-state-key") + + # Verify result + assert result.provider_id == "test-provider" + assert result.tenant_id == "test-tenant" + assert result.server_url == "https://example.com" + + # Verify Redis calls + mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key") + mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key") + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_retrieve_redis_state_not_found(self, mock_redis): + """Test retrieving non-existent state from Redis.""" + mock_redis.get.return_value = None + + with pytest.raises(ValueError) as exc_info: + _retrieve_redis_state("nonexistent-key") + + assert "State parameter has expired or does not exist" in str(exc_info.value) + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_retrieve_redis_state_invalid_json(self, mock_redis): + """Test retrieving invalid JSON state from Redis.""" + mock_redis.get.return_value = '{"invalid": json}' + + with pytest.raises(ValueError) as exc_info: + _retrieve_redis_state("test-key") + + assert "Invalid state parameter" in str(exc_info.value) + # State should still be deleted + mock_redis.delete.assert_called_once() + + +class TestOAuthDiscovery: + """Test OAuth discovery functions.""" + + @patch("core.helper.ssrf_proxy.get") + def test_check_support_resource_discovery_success(self, mock_get): + """Test successful resource discovery check.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]} + mock_get.return_value = mock_response + + supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint") + + assert supported is True + assert auth_url == "https://auth.example.com" + mock_get.assert_called_once_with( + "https://api.example.com/.well-known/oauth-protected-resource", + headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, + ) + + @patch("core.helper.ssrf_proxy.get") + def test_check_support_resource_discovery_not_supported(self, mock_get): + """Test resource discovery not supported.""" + mock_response = Mock() + mock_response.status_code = 404 + mock_get.return_value = mock_response + + supported, auth_url = check_support_resource_discovery("https://api.example.com") + + assert supported is False + assert auth_url == "" + + @patch("core.helper.ssrf_proxy.get") + def test_check_support_resource_discovery_with_query_fragment(self, mock_get): + """Test resource discovery with query and fragment.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]} + mock_get.return_value = mock_response + + supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment") + + assert supported is True + assert auth_url == "https://auth.example.com" + mock_get.assert_called_once_with( + "https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment", + headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, + ) + + @patch("core.helper.ssrf_proxy.get") + def test_discover_oauth_metadata_with_resource_discovery(self, mock_get): + """Test OAuth metadata discovery with resource discovery support.""" + with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: + mock_check.return_value = (True, "https://auth.example.com") + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.is_success = True + mock_response.json.return_value = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "response_types_supported": ["code"], + } + mock_get.return_value = mock_response + + metadata = discover_oauth_metadata("https://api.example.com") + + assert metadata is not None + assert metadata.authorization_endpoint == "https://auth.example.com/authorize" + assert metadata.token_endpoint == "https://auth.example.com/token" + mock_get.assert_called_once_with( + "https://auth.example.com/.well-known/oauth-authorization-server", + headers={"MCP-Protocol-Version": "2025-03-26"}, + ) + + @patch("core.helper.ssrf_proxy.get") + def test_discover_oauth_metadata_without_resource_discovery(self, mock_get): + """Test OAuth metadata discovery without resource discovery.""" + with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: + mock_check.return_value = (False, "") + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.is_success = True + mock_response.json.return_value = { + "authorization_endpoint": "https://api.example.com/oauth/authorize", + "token_endpoint": "https://api.example.com/oauth/token", + "response_types_supported": ["code"], + } + mock_get.return_value = mock_response + + metadata = discover_oauth_metadata("https://api.example.com") + + assert metadata is not None + assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize" + mock_get.assert_called_once_with( + "https://api.example.com/.well-known/oauth-authorization-server", + headers={"MCP-Protocol-Version": "2025-03-26"}, + ) + + @patch("core.helper.ssrf_proxy.get") + def test_discover_oauth_metadata_not_found(self, mock_get): + """Test OAuth metadata discovery when not found.""" + with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: + mock_check.return_value = (False, "") + + mock_response = Mock() + mock_response.status_code = 404 + mock_get.return_value = mock_response + + metadata = discover_oauth_metadata("https://api.example.com") + + assert metadata is None + + +class TestAuthorizationFlow: + """Test authorization flow functions.""" + + @patch("core.mcp.auth.auth_flow._create_secure_redis_state") + def test_start_authorization_with_metadata(self, mock_create_state): + """Test starting authorization with metadata.""" + mock_create_state.return_value = "secure-state-key" + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + code_challenge_methods_supported=["S256"], + ) + client_info = OAuthClientInformation(client_id="test-client-id") + + auth_url, code_verifier = start_authorization( + "https://api.example.com", + metadata, + client_info, + "https://redirect.example.com", + "provider-id", + "tenant-id", + ) + + # Verify URL format + assert auth_url.startswith("https://auth.example.com/authorize?") + assert "response_type=code" in auth_url + assert "client_id=test-client-id" in auth_url + assert "code_challenge=" in auth_url + assert "code_challenge_method=S256" in auth_url + assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url + assert "state=secure-state-key" in auth_url + + # Verify code verifier + assert len(code_verifier) > 40 + + # Verify state was stored + mock_create_state.assert_called_once() + state_data = mock_create_state.call_args[0][0] + assert state_data.provider_id == "provider-id" + assert state_data.tenant_id == "tenant-id" + assert state_data.code_verifier == code_verifier + + def test_start_authorization_without_metadata(self): + """Test starting authorization without metadata.""" + with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state: + mock_create_state.return_value = "secure-state-key" + + client_info = OAuthClientInformation(client_id="test-client-id") + + auth_url, code_verifier = start_authorization( + "https://api.example.com", + None, + client_info, + "https://redirect.example.com", + "provider-id", + "tenant-id", + ) + + # Should use default authorization endpoint + assert auth_url.startswith("https://api.example.com/authorize?") + + def test_start_authorization_invalid_metadata(self): + """Test starting authorization with invalid metadata.""" + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["token"], # No "code" support + code_challenge_methods_supported=["plain"], # No "S256" support + ) + client_info = OAuthClientInformation(client_id="test-client-id") + + with pytest.raises(ValueError) as exc_info: + start_authorization( + "https://api.example.com", + metadata, + client_info, + "https://redirect.example.com", + "provider-id", + "tenant-id", + ) + + assert "does not support response type code" in str(exc_info.value) + + @patch("core.helper.ssrf_proxy.post") + def test_exchange_authorization_success(self, mock_post): + """Test successful authorization code exchange.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "access_token": "new-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh-token", + } + mock_post.return_value = mock_response + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret") + + tokens = exchange_authorization( + "https://api.example.com", + metadata, + client_info, + "auth-code-123", + "code-verifier-xyz", + "https://redirect.example.com", + ) + + assert tokens.access_token == "new-access-token" + assert tokens.token_type == "Bearer" + assert tokens.expires_in == 3600 + assert tokens.refresh_token == "new-refresh-token" + + # Verify request + mock_post.assert_called_once_with( + "https://auth.example.com/token", + data={ + "grant_type": "authorization_code", + "client_id": "test-client-id", + "client_secret": "test-secret", + "code": "auth-code-123", + "code_verifier": "code-verifier-xyz", + "redirect_uri": "https://redirect.example.com", + }, + ) + + @patch("core.helper.ssrf_proxy.post") + def test_exchange_authorization_failure(self, mock_post): + """Test failed authorization code exchange.""" + mock_response = Mock() + mock_response.is_success = False + mock_response.status_code = 400 + mock_post.return_value = mock_response + + client_info = OAuthClientInformation(client_id="test-client-id") + + with pytest.raises(ValueError) as exc_info: + exchange_authorization( + "https://api.example.com", + None, + client_info, + "invalid-code", + "code-verifier", + "https://redirect.example.com", + ) + + assert "Token exchange failed: HTTP 400" in str(exc_info.value) + + @patch("core.helper.ssrf_proxy.post") + def test_refresh_authorization_success(self, mock_post): + """Test successful token refresh.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "access_token": "refreshed-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh-token", + } + mock_post.return_value = mock_response + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["refresh_token"], + ) + client_info = OAuthClientInformation(client_id="test-client-id") + + tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token") + + assert tokens.access_token == "refreshed-access-token" + assert tokens.refresh_token == "new-refresh-token" + + # Verify request + mock_post.assert_called_once_with( + "https://auth.example.com/token", + data={ + "grant_type": "refresh_token", + "client_id": "test-client-id", + "refresh_token": "old-refresh-token", + }, + ) + + @patch("core.helper.ssrf_proxy.post") + def test_register_client_success(self, mock_post): + """Test successful client registration.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "client_name": "Dify", + "redirect_uris": ["https://redirect.example.com"], + } + mock_post.return_value = mock_response + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint="https://auth.example.com/register", + response_types_supported=["code"], + ) + client_metadata = OAuthClientMetadata( + client_name="Dify", + redirect_uris=["https://redirect.example.com"], + grant_types=["authorization_code"], + response_types=["code"], + ) + + client_info = register_client("https://api.example.com", metadata, client_metadata) + + assert isinstance(client_info, OAuthClientInformationFull) + assert client_info.client_id == "new-client-id" + assert client_info.client_secret == "new-client-secret" + + # Verify request + mock_post.assert_called_once_with( + "https://auth.example.com/register", + json=client_metadata.model_dump(), + headers={"Content-Type": "application/json"}, + ) + + def test_register_client_no_endpoint(self): + """Test client registration when no endpoint available.""" + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint=None, + response_types_supported=["code"], + ) + client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"]) + + with pytest.raises(ValueError) as exc_info: + register_client("https://api.example.com", metadata, client_metadata) + + assert "does not support dynamic client registration" in str(exc_info.value) + + +class TestCallbackHandling: + """Test OAuth callback handling.""" + + @patch("core.mcp.auth.auth_flow._retrieve_redis_state") + @patch("core.mcp.auth.auth_flow.exchange_authorization") + def test_handle_callback_success(self, mock_exchange, mock_retrieve_state): + """Test successful callback handling.""" + # Setup state + state_data = OAuthCallbackState( + provider_id="test-provider", + tenant_id="test-tenant", + server_url="https://api.example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="test-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + mock_retrieve_state.return_value = state_data + + # Setup token exchange + tokens = OAuthTokens( + access_token="new-token", + token_type="Bearer", + expires_in=3600, + ) + mock_exchange.return_value = tokens + + # Setup service + mock_service = Mock() + + state_result, tokens_result = handle_callback("state-key", "auth-code") + + assert state_result == state_data + assert tokens_result == tokens + + # Verify calls + mock_retrieve_state.assert_called_once_with("state-key") + mock_exchange.assert_called_once_with( + "https://api.example.com", + None, + state_data.client_information, + "auth-code", + "test-verifier", + "https://redirect.example.com", + ) + # Note: handle_callback no longer saves tokens directly, it just returns them + # The caller (e.g., controller) is responsible for saving via execute_auth_actions + + +class TestAuthOrchestration: + """Test the main auth orchestration function.""" + + @pytest.fixture + def mock_provider(self): + """Create a mock provider entity.""" + provider = Mock(spec=MCPProviderEntity) + provider.id = "provider-id" + provider.tenant_id = "tenant-id" + provider.decrypt_server_url.return_value = "https://api.example.com" + provider.client_metadata = OAuthClientMetadata( + client_name="Dify", + redirect_uris=["https://redirect.example.com"], + ) + provider.redirect_url = "https://redirect.example.com" + provider.retrieve_client_information.return_value = None + provider.retrieve_tokens.return_value = None + return provider + + @pytest.fixture + def mock_service(self): + """Create a mock MCP service.""" + return Mock() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + @patch("core.mcp.auth.auth_flow.register_client") + @patch("core.mcp.auth.auth_flow.start_authorization") + def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service): + """Test auth flow for new client registration.""" + # Setup + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_register.return_value = OAuthClientInformationFull( + client_id="new-client-id", + client_name="Dify", + redirect_uris=["https://redirect.example.com"], + ) + mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier") + + result = auth(mock_provider) + + # auth() now returns AuthResult + assert isinstance(result, AuthResult) + assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."} + + # Verify that the result contains the correct actions + assert len(result.actions) == 2 + # Check for SAVE_CLIENT_INFO action + client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO) + assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()} + assert client_info_action.provider_id == "provider-id" + assert client_info_action.tenant_id == "tenant-id" + + # Check for SAVE_CODE_VERIFIER action + verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER) + assert verifier_action.data == {"code_verifier": "code-verifier"} + assert verifier_action.provider_id == "provider-id" + assert verifier_action.tenant_id == "tenant-id" + + # Verify calls + mock_register.assert_called_once() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + @patch("core.mcp.auth.auth_flow._retrieve_redis_state") + @patch("core.mcp.auth.auth_flow.exchange_authorization") + def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service): + """Test auth flow for exchanging authorization code.""" + # Setup metadata discovery + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + # Setup existing client + mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") + + # Setup state retrieval + state_data = OAuthCallbackState( + provider_id="provider-id", + tenant_id="tenant-id", + server_url="https://api.example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="existing-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + mock_retrieve_state.return_value = state_data + + # Setup token exchange + tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600) + mock_exchange.return_value = tokens + + result = auth(mock_provider, authorization_code="auth-code", state_param="state-key") + + # auth() now returns AuthResult, not a dict + assert isinstance(result, AuthResult) + assert result.response == {"result": "success"} + + # Verify that the result contains the correct action + assert len(result.actions) == 1 + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data == tokens.model_dump() + assert result.actions[0].provider_id == "provider-id" + assert result.actions[0].tenant_id == "tenant-id" + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service): + """Test auth flow fails when exchanging code without state.""" + # Setup metadata discovery + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") + + with pytest.raises(ValueError) as exc_info: + auth(mock_provider, authorization_code="auth-code") + + assert "State parameter is required" in str(exc_info.value) + + @patch("core.mcp.auth.auth_flow.refresh_authorization") + def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service): + """Test auth flow for refreshing tokens.""" + # Setup existing client and tokens + mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") + mock_provider.retrieve_tokens.return_value = OAuthTokens( + access_token="old-token", + token_type="Bearer", + expires_in=0, + refresh_token="refresh-token", + ) + + # Setup refresh + new_tokens = OAuthTokens( + access_token="refreshed-token", + token_type="Bearer", + expires_in=3600, + refresh_token="new-refresh-token", + ) + mock_refresh.return_value = new_tokens + + with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover: + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + result = auth(mock_provider) + + # auth() now returns AuthResult + assert isinstance(result, AuthResult) + assert result.response == {"result": "success"} + + # Verify that the result contains the correct action + assert len(result.actions) == 1 + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data == new_tokens.model_dump() + assert result.actions[0].provider_id == "provider-id" + assert result.actions[0].tenant_id == "tenant-id" + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service): + """Test auth fails when no client info exists but code is provided.""" + # Setup metadata discovery + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + mock_provider.retrieve_client_information.return_value = None + + with pytest.raises(ValueError) as exc_info: + auth(mock_provider, authorization_code="auth-code") + + assert "Existing OAuth client information is required" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/mcp/test_auth_client_inheritance.py b/api/tests/unit_tests/core/mcp/test_auth_client_inheritance.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/mcp/test_entities.py b/api/tests/unit_tests/core/mcp/test_entities.py new file mode 100644 index 0000000000..3fede55916 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_entities.py @@ -0,0 +1,239 @@ +"""Unit tests for MCP entities module.""" + +from unittest.mock import Mock + +from core.mcp.entities import ( + SUPPORTED_PROTOCOL_VERSIONS, + LifespanContextT, + RequestContext, + SessionT, +) +from core.mcp.session.base_session import BaseSession +from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams + + +class TestProtocolVersions: + """Test protocol version constants.""" + + def test_supported_protocol_versions(self): + """Test supported protocol versions list.""" + assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list) + assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3 + assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS + assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS + assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS + + def test_latest_protocol_version_is_supported(self): + """Test that latest protocol version is in supported versions.""" + assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS + + +class TestRequestContext: + """Test RequestContext dataclass.""" + + def test_request_context_creation(self): + """Test creating a RequestContext instance.""" + mock_session = Mock(spec=BaseSession) + mock_lifespan = {"key": "value"} + mock_meta = RequestParams.Meta(progressToken="test-token") + + context = RequestContext( + request_id="test-request-123", + meta=mock_meta, + session=mock_session, + lifespan_context=mock_lifespan, + ) + + assert context.request_id == "test-request-123" + assert context.meta == mock_meta + assert context.session == mock_session + assert context.lifespan_context == mock_lifespan + + def test_request_context_with_none_meta(self): + """Test creating RequestContext with None meta.""" + mock_session = Mock(spec=BaseSession) + + context = RequestContext( + request_id=42, # Can be int or string + meta=None, + session=mock_session, + lifespan_context=None, + ) + + assert context.request_id == 42 + assert context.meta is None + assert context.session == mock_session + assert context.lifespan_context is None + + def test_request_context_attributes(self): + """Test RequestContext attributes are accessible.""" + mock_session = Mock(spec=BaseSession) + + context = RequestContext( + request_id="test-123", + meta=None, + session=mock_session, + lifespan_context=None, + ) + + # Verify attributes are accessible + assert hasattr(context, "request_id") + assert hasattr(context, "meta") + assert hasattr(context, "session") + assert hasattr(context, "lifespan_context") + + # Verify values + assert context.request_id == "test-123" + assert context.meta is None + assert context.session == mock_session + assert context.lifespan_context is None + + def test_request_context_generic_typing(self): + """Test RequestContext with different generic types.""" + # Create a mock session with specific type + mock_session = Mock(spec=BaseSession) + + # Create context with string lifespan context + context_str = RequestContext[BaseSession, str]( + request_id="test-1", + meta=None, + session=mock_session, + lifespan_context="string-context", + ) + assert isinstance(context_str.lifespan_context, str) + + # Create context with dict lifespan context + context_dict = RequestContext[BaseSession, dict]( + request_id="test-2", + meta=None, + session=mock_session, + lifespan_context={"key": "value"}, + ) + assert isinstance(context_dict.lifespan_context, dict) + + # Create context with custom object lifespan context + class CustomLifespan: + def __init__(self, data): + self.data = data + + custom_lifespan = CustomLifespan("test-data") + context_custom = RequestContext[BaseSession, CustomLifespan]( + request_id="test-3", + meta=None, + session=mock_session, + lifespan_context=custom_lifespan, + ) + assert isinstance(context_custom.lifespan_context, CustomLifespan) + assert context_custom.lifespan_context.data == "test-data" + + def test_request_context_with_progress_meta(self): + """Test RequestContext with progress metadata.""" + mock_session = Mock(spec=BaseSession) + progress_meta = RequestParams.Meta(progressToken="progress-123") + + context = RequestContext( + request_id="req-456", + meta=progress_meta, + session=mock_session, + lifespan_context=None, + ) + + assert context.meta is not None + assert context.meta.progressToken == "progress-123" + + def test_request_context_equality(self): + """Test RequestContext equality comparison.""" + mock_session1 = Mock(spec=BaseSession) + mock_session2 = Mock(spec=BaseSession) + + context1 = RequestContext( + request_id="test-123", + meta=None, + session=mock_session1, + lifespan_context="context", + ) + + context2 = RequestContext( + request_id="test-123", + meta=None, + session=mock_session1, + lifespan_context="context", + ) + + context3 = RequestContext( + request_id="test-456", + meta=None, + session=mock_session1, + lifespan_context="context", + ) + + # Same values should be equal + assert context1 == context2 + + # Different request_id should not be equal + assert context1 != context3 + + # Different session should not be equal + context4 = RequestContext( + request_id="test-123", + meta=None, + session=mock_session2, + lifespan_context="context", + ) + assert context1 != context4 + + def test_request_context_repr(self): + """Test RequestContext string representation.""" + mock_session = Mock(spec=BaseSession) + mock_session.__repr__ = Mock(return_value="") + + context = RequestContext( + request_id="test-123", + meta=None, + session=mock_session, + lifespan_context={"data": "test"}, + ) + + repr_str = repr(context) + assert "RequestContext" in repr_str + assert "test-123" in repr_str + assert "MockSession" in repr_str + + +class TestTypeVariables: + """Test type variables defined in the module.""" + + def test_session_type_var(self): + """Test SessionT type variable.""" + + # Create a custom session class + class CustomSession(BaseSession): + pass + + # Use in generic context + def process_session(session: SessionT) -> SessionT: + return session + + mock_session = Mock(spec=CustomSession) + result = process_session(mock_session) + assert result == mock_session + + def test_lifespan_context_type_var(self): + """Test LifespanContextT type variable.""" + + # Use in generic context + def process_lifespan(context: LifespanContextT) -> LifespanContextT: + return context + + # Test with different types + str_context = "string-context" + assert process_lifespan(str_context) == str_context + + dict_context = {"key": "value"} + assert process_lifespan(dict_context) == dict_context + + class CustomContext: + pass + + custom_context = CustomContext() + assert process_lifespan(custom_context) == custom_context diff --git a/api/tests/unit_tests/core/mcp/test_error.py b/api/tests/unit_tests/core/mcp/test_error.py new file mode 100644 index 0000000000..3a95fae673 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_error.py @@ -0,0 +1,205 @@ +"""Unit tests for MCP error classes.""" + +import pytest + +from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError + + +class TestMCPError: + """Test MCPError base exception class.""" + + def test_mcp_error_creation(self): + """Test creating MCPError instance.""" + error = MCPError("Test error message") + assert str(error) == "Test error message" + assert isinstance(error, Exception) + + def test_mcp_error_inheritance(self): + """Test MCPError inherits from Exception.""" + error = MCPError() + assert isinstance(error, Exception) + assert type(error).__name__ == "MCPError" + + def test_mcp_error_with_empty_message(self): + """Test MCPError with empty message.""" + error = MCPError() + assert str(error) == "" + + def test_mcp_error_raise(self): + """Test raising MCPError.""" + with pytest.raises(MCPError) as exc_info: + raise MCPError("Something went wrong") + + assert str(exc_info.value) == "Something went wrong" + + +class TestMCPConnectionError: + """Test MCPConnectionError exception class.""" + + def test_mcp_connection_error_creation(self): + """Test creating MCPConnectionError instance.""" + error = MCPConnectionError("Connection failed") + assert str(error) == "Connection failed" + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_connection_error_inheritance(self): + """Test MCPConnectionError inheritance chain.""" + error = MCPConnectionError() + assert isinstance(error, MCPConnectionError) + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_connection_error_raise(self): + """Test raising MCPConnectionError.""" + with pytest.raises(MCPConnectionError) as exc_info: + raise MCPConnectionError("Unable to connect to server") + + assert str(exc_info.value) == "Unable to connect to server" + + def test_mcp_connection_error_catch_as_mcp_error(self): + """Test catching MCPConnectionError as MCPError.""" + with pytest.raises(MCPError) as exc_info: + raise MCPConnectionError("Connection issue") + + assert isinstance(exc_info.value, MCPConnectionError) + assert str(exc_info.value) == "Connection issue" + + +class TestMCPAuthError: + """Test MCPAuthError exception class.""" + + def test_mcp_auth_error_creation(self): + """Test creating MCPAuthError instance.""" + error = MCPAuthError("Authentication failed") + assert str(error) == "Authentication failed" + assert isinstance(error, MCPConnectionError) + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_auth_error_inheritance(self): + """Test MCPAuthError inheritance chain.""" + error = MCPAuthError() + assert isinstance(error, MCPAuthError) + assert isinstance(error, MCPConnectionError) + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_auth_error_raise(self): + """Test raising MCPAuthError.""" + with pytest.raises(MCPAuthError) as exc_info: + raise MCPAuthError("Invalid credentials") + + assert str(exc_info.value) == "Invalid credentials" + + def test_mcp_auth_error_catch_hierarchy(self): + """Test catching MCPAuthError at different levels.""" + # Catch as MCPAuthError + with pytest.raises(MCPAuthError) as exc_info: + raise MCPAuthError("Auth specific error") + assert str(exc_info.value) == "Auth specific error" + + # Catch as MCPConnectionError + with pytest.raises(MCPConnectionError) as exc_info: + raise MCPAuthError("Auth connection error") + assert isinstance(exc_info.value, MCPAuthError) + assert str(exc_info.value) == "Auth connection error" + + # Catch as MCPError + with pytest.raises(MCPError) as exc_info: + raise MCPAuthError("Auth base error") + assert isinstance(exc_info.value, MCPAuthError) + assert str(exc_info.value) == "Auth base error" + + +class TestErrorHierarchy: + """Test the complete error hierarchy.""" + + def test_exception_hierarchy(self): + """Test the complete exception hierarchy.""" + # Create instances + base_error = MCPError("base") + connection_error = MCPConnectionError("connection") + auth_error = MCPAuthError("auth") + + # Test type relationships + assert not isinstance(base_error, MCPConnectionError) + assert not isinstance(base_error, MCPAuthError) + + assert isinstance(connection_error, MCPError) + assert not isinstance(connection_error, MCPAuthError) + + assert isinstance(auth_error, MCPError) + assert isinstance(auth_error, MCPConnectionError) + + def test_error_handling_patterns(self): + """Test common error handling patterns.""" + + def raise_auth_error(): + raise MCPAuthError("401 Unauthorized") + + def raise_connection_error(): + raise MCPConnectionError("Connection timeout") + + def raise_base_error(): + raise MCPError("Generic error") + + # Pattern 1: Catch specific errors first + errors_caught = [] + + for error_func in [raise_auth_error, raise_connection_error, raise_base_error]: + try: + error_func() + except MCPAuthError: + errors_caught.append("auth") + except MCPConnectionError: + errors_caught.append("connection") + except MCPError: + errors_caught.append("base") + + assert errors_caught == ["auth", "connection", "base"] + + # Pattern 2: Catch all as base error + for error_func in [raise_auth_error, raise_connection_error, raise_base_error]: + with pytest.raises(MCPError) as exc_info: + error_func() + assert isinstance(exc_info.value, MCPError) + + def test_error_with_cause(self): + """Test errors with cause (chained exceptions).""" + original_error = ValueError("Original error") + + def raise_chained_error(): + try: + raise original_error + except ValueError as e: + raise MCPConnectionError("Connection failed") from e + + with pytest.raises(MCPConnectionError) as exc_info: + raise_chained_error() + + assert str(exc_info.value) == "Connection failed" + assert exc_info.value.__cause__ == original_error + + def test_error_comparison(self): + """Test error instance comparison.""" + error1 = MCPError("Test message") + error2 = MCPError("Test message") + error3 = MCPError("Different message") + + # Errors are not equal even with same message (different instances) + assert error1 != error2 + assert error1 != error3 + + # But they have the same type + assert type(error1) == type(error2) == type(error3) + + def test_error_representation(self): + """Test error string representation.""" + base_error = MCPError("Base error message") + connection_error = MCPConnectionError("Connection error message") + auth_error = MCPAuthError("Auth error message") + + assert repr(base_error) == "MCPError('Base error message')" + assert repr(connection_error) == "MCPConnectionError('Connection error message')" + assert repr(auth_error) == "MCPAuthError('Auth error message')" diff --git a/api/tests/unit_tests/core/mcp/test_mcp_client.py b/api/tests/unit_tests/core/mcp/test_mcp_client.py new file mode 100644 index 0000000000..c0420d3371 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_mcp_client.py @@ -0,0 +1,382 @@ +"""Unit tests for MCP client.""" + +from contextlib import ExitStack +from types import TracebackType +from unittest.mock import Mock, patch + +import pytest + +from core.mcp.error import MCPConnectionError +from core.mcp.mcp_client import MCPClient +from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations + + +class TestMCPClient: + """Test suite for MCPClient.""" + + def test_init(self): + """Test client initialization.""" + client = MCPClient( + server_url="http://test.example.com/mcp", + headers={"Authorization": "Bearer test"}, + timeout=30.0, + sse_read_timeout=60.0, + ) + + assert client.server_url == "http://test.example.com/mcp" + assert client.headers == {"Authorization": "Bearer test"} + assert client.timeout == 30.0 + assert client.sse_read_timeout == 60.0 + assert client._session is None + assert isinstance(client._exit_stack, ExitStack) + assert client._initialized is False + + def test_init_defaults(self): + """Test client initialization with defaults.""" + client = MCPClient(server_url="http://test.example.com") + + assert client.server_url == "http://test.example.com" + assert client.headers == {} + assert client.timeout is None + assert client.sse_read_timeout is None + + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client): + """Test initialization with MCP URL.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/mcp") + client._initialize() + + # Verify streamable client was called + mock_streamable_client.assert_called_once_with( + url="http://test.example.com/mcp", + headers={}, + timeout=None, + sse_read_timeout=None, + ) + + # Verify session was created + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + assert client._session == mock_session + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client): + """Test initialization with SSE URL.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/sse") + client._initialize() + + # Verify SSE client was called + mock_sse_client.assert_called_once_with( + url="http://test.example.com/sse", + headers={}, + timeout=None, + sse_read_timeout=None, + ) + + # Verify session was created + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + assert client._session == mock_session + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_with_unknown_method_fallback_to_sse( + self, mock_client_session, mock_streamable_client, mock_sse_client + ): + """Test initialization with unknown method falls back to SSE.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/unknown") + client._initialize() + + # Verify SSE client was tried + mock_sse_client.assert_called_once() + mock_streamable_client.assert_not_called() + + # Verify session was created + assert client._session == mock_session + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client): + """Test initialization falls back from SSE to MCP on connection error.""" + # Setup SSE to fail + mock_sse_client.side_effect = MCPConnectionError("SSE connection failed") + + # Setup MCP to succeed + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/unknown") + client._initialize() + + # Verify both were tried + mock_sse_client.assert_called_once() + mock_streamable_client.assert_called_once() + + # Verify session was created with MCP + assert client._session == mock_session + + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_connect_server_mcp(self, mock_client_session, mock_streamable_client): + """Test connect_server with MCP method.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com") + client.connect_server(mock_streamable_client, "mcp") + + # Verify correct streams were passed + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_connect_server_sse(self, mock_client_session, mock_sse_client): + """Test connect_server with SSE method.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com") + client.connect_server(mock_sse_client, "sse") + + # Verify correct streams were passed + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + + def test_context_manager_enter(self): + """Test context manager enter.""" + client = MCPClient(server_url="http://test.example.com") + + with patch.object(client, "_initialize") as mock_initialize: + result = client.__enter__() + + assert result == client + assert client._initialized is True + mock_initialize.assert_called_once() + + def test_context_manager_exit(self): + """Test context manager exit.""" + client = MCPClient(server_url="http://test.example.com") + + with patch.object(client, "cleanup") as mock_cleanup: + exc_type: type[BaseException] | None = None + exc_val: BaseException | None = None + exc_tb: TracebackType | None = None + client.__exit__(exc_type, exc_val, exc_tb) + + mock_cleanup.assert_called_once() + + def test_list_tools_not_initialized(self): + """Test list_tools when session not initialized.""" + client = MCPClient(server_url="http://test.example.com") + + with pytest.raises(ValueError) as exc_info: + client.list_tools() + + assert "Session not initialized" in str(exc_info.value) + + def test_list_tools_success(self): + """Test successful list_tools call.""" + client = MCPClient(server_url="http://test.example.com") + + # Setup mock session + mock_session = Mock() + expected_tools = [ + Tool( + name="test-tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(title="Test Tool"), + ) + ] + mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools) + client._session = mock_session + + result = client.list_tools() + + assert result == expected_tools + mock_session.list_tools.assert_called_once() + + def test_invoke_tool_not_initialized(self): + """Test invoke_tool when session not initialized.""" + client = MCPClient(server_url="http://test.example.com") + + with pytest.raises(ValueError) as exc_info: + client.invoke_tool("test-tool", {"arg": "value"}) + + assert "Session not initialized" in str(exc_info.value) + + def test_invoke_tool_success(self): + """Test successful invoke_tool call.""" + client = MCPClient(server_url="http://test.example.com") + + # Setup mock session + mock_session = Mock() + expected_result = CallToolResult( + content=[TextContent(type="text", text="Tool executed successfully")], + isError=False, + ) + mock_session.call_tool.return_value = expected_result + client._session = mock_session + + result = client.invoke_tool("test-tool", {"arg": "value"}) + + assert result == expected_result + mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"}) + + def test_cleanup(self): + """Test cleanup method.""" + client = MCPClient(server_url="http://test.example.com") + mock_exit_stack = Mock(spec=ExitStack) + client._exit_stack = mock_exit_stack + client._session = Mock() + client._initialized = True + + client.cleanup() + + mock_exit_stack.close.assert_called_once() + assert client._session is None + assert client._initialized is False + + def test_cleanup_with_error(self): + """Test cleanup method with error.""" + client = MCPClient(server_url="http://test.example.com") + mock_exit_stack = Mock(spec=ExitStack) + mock_exit_stack.close.side_effect = Exception("Cleanup error") + client._exit_stack = mock_exit_stack + client._session = Mock() + client._initialized = True + + with pytest.raises(ValueError) as exc_info: + client.cleanup() + + assert "Error during cleanup: Cleanup error" in str(exc_info.value) + assert client._session is None + assert client._initialized is False + + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client): + """Test full context manager flow.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})] + mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools) + + with MCPClient(server_url="http://test.example.com/mcp") as client: + assert client._initialized is True + assert client._session == mock_session + + # Test tool operations + tools = client.list_tools() + assert tools == expected_tools + + # After exit, should be cleaned up + assert client._initialized is False + assert client._session is None + + def test_headers_passed_to_clients(self): + """Test that headers are properly passed to underlying clients.""" + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + } + + with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client: + with patch("core.mcp.mcp_client.ClientSession") as mock_client_session: + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient( + server_url="http://test.example.com/mcp", + headers=custom_headers, + timeout=30.0, + sse_read_timeout=60.0, + ) + client._initialize() + + # Verify headers were passed + mock_streamable_client.assert_called_once_with( + url="http://test.example.com/mcp", + headers=custom_headers, + timeout=30.0, + sse_read_timeout=60.0, + ) diff --git a/api/tests/unit_tests/core/mcp/test_types.py b/api/tests/unit_tests/core/mcp/test_types.py new file mode 100644 index 0000000000..6d8130bd13 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_types.py @@ -0,0 +1,492 @@ +"""Unit tests for MCP types module.""" + +import pytest +from pydantic import ValidationError + +from core.mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + PARSE_ERROR, + SERVER_LATEST_PROTOCOL_VERSION, + Annotations, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientCapabilities, + CompleteRequest, + CompleteRequestParams, + CompleteResult, + Completion, + CompletionArgument, + CompletionContext, + ErrorData, + ImageContent, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + OAuthClientInformation, + OAuthClientMetadata, + OAuthMetadata, + OAuthTokens, + PingRequest, + ProgressNotification, + ProgressNotificationParams, + PromptReference, + RequestParams, + ResourceTemplateReference, + Result, + ServerCapabilities, + TextContent, + Tool, + ToolAnnotations, +) + + +class TestConstants: + """Test module constants.""" + + def test_protocol_versions(self): + """Test protocol version constants.""" + assert LATEST_PROTOCOL_VERSION == "2025-03-26" + assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05" + + def test_error_codes(self): + """Test JSON-RPC error code constants.""" + assert PARSE_ERROR == -32700 + assert INVALID_REQUEST == -32600 + assert METHOD_NOT_FOUND == -32601 + assert INVALID_PARAMS == -32602 + assert INTERNAL_ERROR == -32603 + + +class TestRequestParams: + """Test RequestParams and related classes.""" + + def test_request_params_basic(self): + """Test basic RequestParams creation.""" + params = RequestParams() + assert params.meta is None + + def test_request_params_with_meta(self): + """Test RequestParams with meta.""" + meta = RequestParams.Meta(progressToken="test-token") + params = RequestParams(_meta=meta) + assert params.meta is not None + assert params.meta.progressToken == "test-token" + + def test_request_params_meta_extra_fields(self): + """Test RequestParams.Meta allows extra fields.""" + meta = RequestParams.Meta(progressToken="token", customField="value") + assert meta.progressToken == "token" + assert meta.customField == "value" # type: ignore + + def test_request_params_serialization(self): + """Test RequestParams serialization with _meta alias.""" + meta = RequestParams.Meta(progressToken="test") + params = RequestParams(_meta=meta) + + # Model dump should use the alias + dumped = params.model_dump(by_alias=True) + assert "_meta" in dumped + assert dumped["_meta"] is not None + assert dumped["_meta"]["progressToken"] == "test" + + +class TestJSONRPCMessages: + """Test JSON-RPC message types.""" + + def test_jsonrpc_request(self): + """Test JSONRPCRequest creation and validation.""" + request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"}) + + assert request.jsonrpc == "2.0" + assert request.id == "test-123" + assert request.method == "test_method" + assert request.params == {"key": "value"} + + def test_jsonrpc_request_numeric_id(self): + """Test JSONRPCRequest with numeric ID.""" + request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None) + assert request.id == 123 + + def test_jsonrpc_notification(self): + """Test JSONRPCNotification creation.""" + notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"}) + + assert notification.jsonrpc == "2.0" + assert notification.method == "notification_method" + assert not hasattr(notification, "id") # Notifications don't have ID + + def test_jsonrpc_response(self): + """Test JSONRPCResponse creation.""" + response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True}) + + assert response.jsonrpc == "2.0" + assert response.id == "req-123" + assert response.result == {"success": True} + + def test_jsonrpc_error(self): + """Test JSONRPCError creation.""" + error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"}) + + error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data) + + assert error.jsonrpc == "2.0" + assert error.id == "req-123" + assert error.error.code == INVALID_PARAMS + assert error.error.message == "Invalid parameters" + assert error.error.data == {"field": "missing"} + + def test_jsonrpc_message_parsing(self): + """Test JSONRPCMessage parsing different message types.""" + # Parse request + request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}' + msg = JSONRPCMessage.model_validate_json(request_json) + assert isinstance(msg.root, JSONRPCRequest) + + # Parse response + response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}' + msg = JSONRPCMessage.model_validate_json(response_json) + assert isinstance(msg.root, JSONRPCResponse) + + # Parse error + error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}' + msg = JSONRPCMessage.model_validate_json(error_json) + assert isinstance(msg.root, JSONRPCError) + + +class TestCapabilities: + """Test capability classes.""" + + def test_client_capabilities(self): + """Test ClientCapabilities creation.""" + caps = ClientCapabilities( + experimental={"feature": {"enabled": True}}, + sampling={"model_config": {"extra": "allow"}}, + roots={"listChanged": True}, + ) + + assert caps.experimental == {"feature": {"enabled": True}} + assert caps.sampling is not None + assert caps.roots.listChanged is True # type: ignore + + def test_server_capabilities(self): + """Test ServerCapabilities creation.""" + caps = ServerCapabilities( + tools={"listChanged": True}, + resources={"subscribe": True, "listChanged": False}, + prompts={"listChanged": True}, + logging={}, + completions={}, + ) + + assert caps.tools.listChanged is True # type: ignore + assert caps.resources.subscribe is True # type: ignore + assert caps.resources.listChanged is False # type: ignore + + +class TestInitialization: + """Test initialization request/response types.""" + + def test_initialize_request(self): + """Test InitializeRequest creation.""" + client_info = Implementation(name="test-client", version="1.0.0") + capabilities = ClientCapabilities() + + params = InitializeRequestParams( + protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info + ) + + request = InitializeRequest(params=params) + + assert request.method == "initialize" + assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION + assert request.params.clientInfo.name == "test-client" + + def test_initialize_result(self): + """Test InitializeResult creation.""" + server_info = Implementation(name="test-server", version="1.0.0") + capabilities = ServerCapabilities() + + result = InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + serverInfo=server_info, + instructions="Welcome to test server", + ) + + assert result.protocolVersion == LATEST_PROTOCOL_VERSION + assert result.serverInfo.name == "test-server" + assert result.instructions == "Welcome to test server" + + +class TestTools: + """Test tool-related types.""" + + def test_tool_creation(self): + """Test Tool creation with all fields.""" + tool = Tool( + name="test_tool", + title="Test Tool", + description="A tool for testing", + inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]}, + outputSchema={"type": "object", "properties": {"result": {"type": "string"}}}, + annotations=ToolAnnotations( + title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True + ), + ) + + assert tool.name == "test_tool" + assert tool.title == "Test Tool" + assert tool.description == "A tool for testing" + assert tool.inputSchema["properties"]["input"]["type"] == "string" + assert tool.annotations.idempotentHint is True + + def test_call_tool_request(self): + """Test CallToolRequest creation.""" + params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"}) + + request = CallToolRequest(params=params) + + assert request.method == "tools/call" + assert request.params.name == "test_tool" + assert request.params.arguments == {"input": "test value"} + + def test_call_tool_result(self): + """Test CallToolResult creation.""" + result = CallToolResult( + content=[TextContent(type="text", text="Tool executed successfully")], + structuredContent={"status": "success", "data": "test"}, + isError=False, + ) + + assert len(result.content) == 1 + assert result.content[0].text == "Tool executed successfully" # type: ignore + assert result.structuredContent == {"status": "success", "data": "test"} + assert result.isError is False + + def test_list_tools_request(self): + """Test ListToolsRequest creation.""" + request = ListToolsRequest() + assert request.method == "tools/list" + + def test_list_tools_result(self): + """Test ListToolsResult creation.""" + tool1 = Tool(name="tool1", inputSchema={}) + tool2 = Tool(name="tool2", inputSchema={}) + + result = ListToolsResult(tools=[tool1, tool2]) + + assert len(result.tools) == 2 + assert result.tools[0].name == "tool1" + assert result.tools[1].name == "tool2" + + +class TestContent: + """Test content types.""" + + def test_text_content(self): + """Test TextContent creation.""" + annotations = Annotations(audience=["user"], priority=0.8) + content = TextContent(type="text", text="Hello, world!", annotations=annotations) + + assert content.type == "text" + assert content.text == "Hello, world!" + assert content.annotations is not None + assert content.annotations.priority == 0.8 + + def test_image_content(self): + """Test ImageContent creation.""" + content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png") + + assert content.type == "image" + assert content.data == "base64encodeddata" + assert content.mimeType == "image/png" + + +class TestOAuth: + """Test OAuth-related types.""" + + def test_oauth_client_metadata(self): + """Test OAuthClientMetadata creation.""" + metadata = OAuthClientMetadata( + client_name="Test Client", + redirect_uris=["https://example.com/callback"], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + client_uri="https://example.com", + scope="read write", + ) + + assert metadata.client_name == "Test Client" + assert len(metadata.redirect_uris) == 1 + assert "authorization_code" in metadata.grant_types + + def test_oauth_client_information(self): + """Test OAuthClientInformation creation.""" + info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret") + + assert info.client_id == "test-client-id" + assert info.client_secret == "test-secret" + + def test_oauth_client_information_without_secret(self): + """Test OAuthClientInformation without secret.""" + info = OAuthClientInformation(client_id="public-client") + + assert info.client_id == "public-client" + assert info.client_secret is None + + def test_oauth_tokens(self): + """Test OAuthTokens creation.""" + tokens = OAuthTokens( + access_token="access-token-123", + token_type="Bearer", + expires_in=3600, + refresh_token="refresh-token-456", + scope="read write", + ) + + assert tokens.access_token == "access-token-123" + assert tokens.token_type == "Bearer" + assert tokens.expires_in == 3600 + assert tokens.refresh_token == "refresh-token-456" + assert tokens.scope == "read write" + + def test_oauth_metadata(self): + """Test OAuthMetadata creation.""" + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint="https://auth.example.com/register", + response_types_supported=["code", "token"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["plain", "S256"], + ) + + assert metadata.authorization_endpoint == "https://auth.example.com/authorize" + assert "code" in metadata.response_types_supported + assert "S256" in metadata.code_challenge_methods_supported + + +class TestNotifications: + """Test notification types.""" + + def test_progress_notification(self): + """Test ProgressNotification creation.""" + params = ProgressNotificationParams( + progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%" + ) + + notification = ProgressNotification(params=params) + + assert notification.method == "notifications/progress" + assert notification.params.progressToken == "progress-123" + assert notification.params.progress == 50.0 + assert notification.params.total == 100.0 + assert notification.params.message == "Processing... 50%" + + def test_ping_request(self): + """Test PingRequest creation.""" + request = PingRequest() + assert request.method == "ping" + assert request.params is None + + +class TestCompletion: + """Test completion-related types.""" + + def test_completion_context(self): + """Test CompletionContext creation.""" + context = CompletionContext(arguments={"template_var": "value"}) + assert context.arguments == {"template_var": "value"} + + def test_resource_template_reference(self): + """Test ResourceTemplateReference creation.""" + ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}") + assert ref.type == "ref/resource" + assert ref.uri == "file:///path/to/{filename}" + + def test_prompt_reference(self): + """Test PromptReference creation.""" + ref = PromptReference(type="ref/prompt", name="test_prompt") + assert ref.type == "ref/prompt" + assert ref.name == "test_prompt" + + def test_complete_request(self): + """Test CompleteRequest creation.""" + ref = PromptReference(type="ref/prompt", name="test_prompt") + arg = CompletionArgument(name="arg1", value="val") + + params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"})) + + request = CompleteRequest(params=params) + + assert request.method == "completion/complete" + assert request.params.ref.name == "test_prompt" # type: ignore + assert request.params.argument.name == "arg1" + + def test_complete_result(self): + """Test CompleteResult creation.""" + completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True) + + result = CompleteResult(completion=completion) + + assert len(result.completion.values) == 3 + assert result.completion.total == 10 + assert result.completion.hasMore is True + + +class TestValidation: + """Test validation of various types.""" + + def test_invalid_jsonrpc_version(self): + """Test invalid JSON-RPC version validation.""" + with pytest.raises(ValidationError): + JSONRPCRequest( + jsonrpc="1.0", # Invalid version + id=1, + method="test", + ) + + def test_tool_annotations_validation(self): + """Test ToolAnnotations with invalid values.""" + # Valid annotations + annotations = ToolAnnotations( + title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False + ) + assert annotations.title == "Test" + + def test_extra_fields_allowed(self): + """Test that extra fields are allowed in models.""" + # Most models should allow extra fields + tool = Tool( + name="test", + inputSchema={}, + customField="allowed", # type: ignore + ) + assert tool.customField == "allowed" # type: ignore + + def test_result_meta_alias(self): + """Test Result model with _meta alias.""" + # Create with the field name (not alias) + result = Result(_meta={"key": "value"}) + + # Verify the field is set correctly + assert result.meta == {"key": "value"} + + # Dump with alias + dumped = result.model_dump(by_alias=True) + assert "_meta" in dumped + assert dumped["_meta"] == {"key": "value"} diff --git a/api/tests/unit_tests/core/mcp/test_utils.py b/api/tests/unit_tests/core/mcp/test_utils.py new file mode 100644 index 0000000000..ca41d5f4c1 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_utils.py @@ -0,0 +1,355 @@ +"""Unit tests for MCP utils module.""" + +import json +from collections.abc import Generator +from unittest.mock import MagicMock, Mock, patch + +import httpx +import httpx_sse +import pytest + +from core.mcp.utils import ( + STATUS_FORCELIST, + create_mcp_error_response, + create_ssrf_proxy_mcp_http_client, + ssrf_proxy_sse_connect, +) + + +class TestConstants: + """Test module constants.""" + + def test_status_forcelist(self): + """Test STATUS_FORCELIST contains expected HTTP status codes.""" + assert STATUS_FORCELIST == [429, 500, 502, 503, 504] + assert 429 in STATUS_FORCELIST # Too Many Requests + assert 500 in STATUS_FORCELIST # Internal Server Error + assert 502 in STATUS_FORCELIST # Bad Gateway + assert 503 in STATUS_FORCELIST # Service Unavailable + assert 504 in STATUS_FORCELIST # Gateway Timeout + + +class TestCreateSSRFProxyMCPHTTPClient: + """Test create_ssrf_proxy_mcp_http_client function.""" + + @patch("core.mcp.utils.dify_config") + def test_create_client_with_all_url_proxy(self, mock_config): + """Test client creation with SSRF_PROXY_ALL_URL configured.""" + mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080" + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True + + client = create_ssrf_proxy_mcp_http_client( + headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0) + ) + + assert isinstance(client, httpx.Client) + assert client.headers["Authorization"] == "Bearer token" + assert client.timeout.connect == 30.0 + assert client.follow_redirects is True + + # Clean up + client.close() + + @patch("core.mcp.utils.dify_config") + def test_create_client_with_http_https_proxies(self, mock_config): + """Test client creation with separate HTTP/HTTPS proxies.""" + mock_config.SSRF_PROXY_ALL_URL = None + mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080" + mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443" + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False + + client = create_ssrf_proxy_mcp_http_client() + + assert isinstance(client, httpx.Client) + assert client.follow_redirects is True + + # Clean up + client.close() + + @patch("core.mcp.utils.dify_config") + def test_create_client_without_proxy(self, mock_config): + """Test client creation without proxy configuration.""" + mock_config.SSRF_PROXY_ALL_URL = None + mock_config.SSRF_PROXY_HTTP_URL = None + mock_config.SSRF_PROXY_HTTPS_URL = None + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True + + headers = {"X-Custom-Header": "value"} + timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0) + + client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) + + assert isinstance(client, httpx.Client) + assert client.headers["X-Custom-Header"] == "value" + assert client.timeout.connect == 5.0 + assert client.timeout.read == 10.0 + assert client.follow_redirects is True + + # Clean up + client.close() + + @patch("core.mcp.utils.dify_config") + def test_create_client_default_params(self, mock_config): + """Test client creation with default parameters.""" + mock_config.SSRF_PROXY_ALL_URL = None + mock_config.SSRF_PROXY_HTTP_URL = None + mock_config.SSRF_PROXY_HTTPS_URL = None + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True + + client = create_ssrf_proxy_mcp_http_client() + + assert isinstance(client, httpx.Client) + # httpx.Client adds default headers, so we just check it's a Headers object + assert isinstance(client.headers, httpx.Headers) + # When no timeout is provided, httpx uses its default timeout + assert client.timeout is not None + + # Clean up + client.close() + + +class TestSSRFProxySSEConnect: + """Test ssrf_proxy_sse_connect function.""" + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse): + """Test SSE connection with pre-configured client.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_event_source = Mock(spec=httpx_sse.EventSource) + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_event_source + mock_connect_sse.return_value = mock_context + + # Call with provided client + result = ssrf_proxy_sse_connect( + "http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"} + ) + + # Verify client creation was not called + mock_create_client.assert_not_called() + + # Verify connect_sse was called correctly + mock_connect_sse.assert_called_once_with( + mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"} + ) + + # Verify result + assert result == mock_context + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + @patch("core.mcp.utils.dify_config") + def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse): + """Test SSE connection without pre-configured client.""" + # Setup config + mock_config.SSRF_DEFAULT_TIME_OUT = 30.0 + mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0 + mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0 + mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0 + + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_create_client.return_value = mock_client + + mock_event_source = Mock(spec=httpx_sse.EventSource) + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_event_source + mock_connect_sse.return_value = mock_context + + # Call without client + result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"}) + + # Verify client was created + mock_create_client.assert_called_once() + call_args = mock_create_client.call_args + assert call_args[1]["headers"] == {"X-Custom": "value"} + + timeout = call_args[1]["timeout"] + # httpx.Timeout object has these attributes + assert isinstance(timeout, httpx.Timeout) + assert timeout.connect == 10.0 + assert timeout.read == 60.0 + assert timeout.write == 30.0 + + # Verify connect_sse was called + mock_connect_sse.assert_called_once_with( + mock_client, + "GET", # Default method + "http://example.com/sse", + ) + + # Verify result + assert result == mock_context + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse): + """Test SSE connection with custom timeout.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_create_client.return_value = mock_client + + mock_event_source = Mock(spec=httpx_sse.EventSource) + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_event_source + mock_connect_sse.return_value = mock_context + + custom_timeout = httpx.Timeout(timeout=60.0, read=120.0) + + # Call with custom timeout + result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout) + + # Verify client was created with custom timeout + mock_create_client.assert_called_once() + call_args = mock_create_client.call_args + assert call_args[1]["timeout"] == custom_timeout + + # Verify result + assert result == mock_context + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse): + """Test SSE connection cleans up client on error.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_create_client.return_value = mock_client + + # Make connect_sse raise an exception + mock_connect_sse.side_effect = httpx.ConnectError("Connection failed") + + # Call should raise the exception + with pytest.raises(httpx.ConnectError): + ssrf_proxy_sse_connect("http://example.com/sse") + + # Verify client was cleaned up + mock_client.close.assert_called_once() + + @patch("core.mcp.utils.connect_sse") + def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse): + """Test SSE connection doesn't clean up provided client on error.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + + # Make connect_sse raise an exception + mock_connect_sse.side_effect = httpx.ConnectError("Connection failed") + + # Call should raise the exception + with pytest.raises(httpx.ConnectError): + ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client) + + # Verify client was NOT cleaned up (because it was provided) + mock_client.close.assert_not_called() + + +class TestCreateMCPErrorResponse: + """Test create_mcp_error_response function.""" + + def test_create_error_response_basic(self): + """Test creating basic error response.""" + generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request") + + # Generator should yield bytes + assert isinstance(generator, Generator) + + # Get the response + response_bytes = next(generator) + assert isinstance(response_bytes, bytes) + + # Parse the response + response_str = response_bytes.decode("utf-8") + response_json = json.loads(response_str) + + assert response_json["jsonrpc"] == "2.0" + assert response_json["id"] == "req-123" + assert response_json["error"]["code"] == -32600 + assert response_json["error"]["message"] == "Invalid Request" + assert response_json["error"]["data"] is None + + # Generator should be exhausted + with pytest.raises(StopIteration): + next(generator) + + def test_create_error_response_with_data(self): + """Test creating error response with additional data.""" + error_data = {"field": "username", "reason": "required"} + + generator = create_mcp_error_response( + request_id=456, # Numeric ID + code=-32602, + message="Invalid params", + data=error_data, + ) + + response_bytes = next(generator) + response_json = json.loads(response_bytes.decode("utf-8")) + + assert response_json["id"] == 456 + assert response_json["error"]["code"] == -32602 + assert response_json["error"]["message"] == "Invalid params" + assert response_json["error"]["data"] == error_data + + def test_create_error_response_without_request_id(self): + """Test creating error response without request ID.""" + generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error") + + response_bytes = next(generator) + response_json = json.loads(response_bytes.decode("utf-8")) + + # Should default to ID 1 + assert response_json["id"] == 1 + assert response_json["error"]["code"] == -32700 + assert response_json["error"]["message"] == "Parse error" + + def test_create_error_response_with_complex_data(self): + """Test creating error response with complex error data.""" + complex_data = { + "errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}], + "timestamp": "2024-01-01T00:00:00Z", + } + + generator = create_mcp_error_response( + request_id="complex-req", code=-32602, message="Validation failed", data=complex_data + ) + + response_bytes = next(generator) + response_json = json.loads(response_bytes.decode("utf-8")) + + assert response_json["error"]["data"] == complex_data + assert len(response_json["error"]["data"]["errors"]) == 2 + + def test_create_error_response_encoding(self): + """Test error response with non-ASCII characters.""" + generator = create_mcp_error_response( + request_id="unicode-req", + code=-32603, + message="内部错误", # Chinese characters + data={"details": "エラー詳細"}, # Japanese characters + ) + + response_bytes = next(generator) + + # Should be valid UTF-8 + response_str = response_bytes.decode("utf-8") + response_json = json.loads(response_str) + + assert response_json["error"]["message"] == "内部错误" + assert response_json["error"]["data"]["details"] == "エラー詳細" + + def test_create_error_response_yields_once(self): + """Test that error response generator yields exactly once.""" + generator = create_mcp_error_response(request_id="test", code=-32600, message="Test") + + # First yield should work + first_yield = next(generator) + assert isinstance(first_yield, bytes) + + # Second yield should raise StopIteration + with pytest.raises(StopIteration): + next(generator) + + # Subsequent calls should also raise + with pytest.raises(StopIteration): + next(generator) diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py index fb0139932b..7511fd6f0c 100644 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py @@ -180,6 +180,25 @@ class TestMCPToolTransform: # Set tools data with null description mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]' + # Mock the to_entity and to_api_response methods + mock_entity = Mock() + mock_entity.to_api_response.return_value = { + "name": "Test MCP Provider", + "type": ToolProviderType.MCP, + "is_team_authorization": True, + "server_url": "https://*****.com/mcp", + "provider_icon": "icon.png", + "masked_headers": {"Authorization": "Bearer *****"}, + "updated_at": 1234567890, + "labels": [], + "author": "Test User", + "description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"), + "icon": "icon.png", + "label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"), + "masked_credentials": {}, + } + mock_provider_full.to_entity.return_value = mock_entity + # Call the method with for_list=True result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True) @@ -198,6 +217,27 @@ class TestMCPToolTransform: # Set tools data with description mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]' + # Mock the to_entity and to_api_response methods + mock_entity = Mock() + mock_entity.to_api_response.return_value = { + "name": "Test MCP Provider", + "type": ToolProviderType.MCP, + "is_team_authorization": True, + "server_url": "https://*****.com/mcp", + "provider_icon": "icon.png", + "masked_headers": {"Authorization": "Bearer *****"}, + "updated_at": 1234567890, + "labels": [], + "configuration": {"timeout": "30", "sse_read_timeout": "300"}, + "original_headers": {"Authorization": "Bearer secret-token"}, + "author": "Test User", + "description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"), + "icon": "icon.png", + "label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"), + "masked_credentials": {}, + } + mock_provider_full.to_entity.return_value = mock_entity + # Call the method with for_list=False result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False) @@ -205,8 +245,9 @@ class TestMCPToolTransform: assert isinstance(result, ToolProviderApiEntity) assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False assert result.server_identifier == "server-identifier-456" - assert result.timeout == 30 - assert result.sse_read_timeout == 300 + assert result.configuration is not None + assert result.configuration.timeout == 30 + assert result.configuration.sse_read_timeout == 300 assert result.original_headers == {"Authorization": "Bearer secret-token"} assert len(result.tools) == 1 assert result.tools[0].description.en_US == "Tool description"