diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 54d64e7085..a8d4f0f5de 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from services.tools.mcp_tools_manage_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_transform_service import ToolTransformService @@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource): args = parser.parse_args() user, tenant_id = current_account_with_tenant() - # Validate server URL - if not is_valid_url(args["server_url"]): - raise ValueError("Server URL is not valid.") - # Parse and validate models configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None @@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource): .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) ) args = parser.parse_args() - if not is_valid_url(args["server_url"]): - if "[__HIDDEN__]" in args["server_url"]: - pass - else: - raise ValueError("Server URL is not valid.") configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() + # Step 1: Validate server URL change if needed (includes URL format validation and network operation) + validation_result = None + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + validation_result = service.validate_server_url_change( + tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + ) + + # No need to check for errors here, exceptions will be raised directly + + # Step 2: Perform database update in a transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( @@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource): headers=args["headers"], configuration=configuration, authentication=authentication, + validation_result=validation_result, ) return {"result": "success"} @@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource): provider_id = args["provider_id"] _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session: - with session.begin(): - service = MCPToolManageService(session=session) - db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) - if not db_provider: - raise ValueError("provider not found") + with Session(db.engine) as session, session.begin(): + service = MCPToolManageService(session=session) + db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) + if not db_provider: + raise ValueError("provider not found") - # Convert to entity - provider_entity = db_provider.to_entity() - server_url = provider_entity.decrypt_server_url() - headers = provider_entity.decrypt_authentication() + # Convert to entity + provider_entity = db_provider.to_entity() + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_authentication() - # Try to connect without active transaction + # Try to connect without active transaction + try: + # Use MCPClientWithAuthRetry to handle authentication automatically + with MCPClient( + server_url=server_url, + headers=headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ): + # Create new transaction for update + with session.begin(): + service.update_provider_credentials( + provider=db_provider, + credentials=provider_entity.credentials, + authed=True, + ) + return {"result": "success"} + except MCPAuthError as e: + service = MCPToolManageService(session=session) try: - # Use MCPClientWithAuthRetry to handle authentication automatically - with MCPClient( - server_url=server_url, - headers=headers, - timeout=provider_entity.timeout, - sse_read_timeout=provider_entity.sse_read_timeout, - ): - # Create new transaction for update - with session.begin(): - service.update_provider_credentials( - provider=db_provider, - credentials=provider_entity.credentials, - authed=True, - ) - return {"result": "success"} - except MCPAuthError as e: - service = MCPToolManageService(session=session) - try: - return auth(provider_entity, service, args.get("authorization_code")) - except MCPRefreshTokenError as e: - with session.begin(): - service.clear_provider_credentials(provider=db_provider) - raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e - except MCPError as e: + auth_result = auth(provider_entity, args.get("authorization_code")) + with session.begin(): + response = service.execute_auth_actions(auth_result) + return response + except MCPRefreshTokenError as e: with session.begin(): service.clear_provider_credentials(provider=db_provider) - raise ValueError(f"Failed to connect to MCP server: {e}") from e + raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e + except MCPError as e: + with session.begin(): + service.clear_provider_credentials(provider=db_provider) + raise ValueError(f"Failed to connect to MCP server: {e}") from e @console_ns.route("/workspaces/current/tool-provider/mcp/tools/") @@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - with Session(db.engine, expire_on_commit=False) as session: + with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) tools = service.list_providers(tenant_id=tenant_id) @@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource): # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): mcp_service = MCPToolManageService(session=session) - handle_callback(state_key, authorization_code, mcp_service) + # handle_callback now returns state data and tokens + state_data, tokens = handle_callback(state_key, authorization_code) + # Save tokens using the service layer + mcp_service.save_oauth_data( + state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS + ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index a1fcd6e033..951c22f6dd 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -4,14 +4,14 @@ import json import os import secrets import urllib.parse -from typing import TYPE_CHECKING from urllib.parse import urljoin, urlparse from httpx import ConnectError, HTTPStatusError, RequestError -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType from core.helper import ssrf_proxy +from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState from core.mcp.error import MCPRefreshTokenError from core.mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -23,23 +23,10 @@ from core.mcp.types import ( ) from extensions.ext_redis import redis_client -if TYPE_CHECKING: - from services.tools.mcp_tools_manage_service import MCPToolManageService - OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" -class OAuthCallbackState(BaseModel): - provider_id: str - tenant_id: str - server_url: str - metadata: OAuthMetadata | None = None - client_information: OAuthClientInformation - code_verifier: str - redirect_uri: str - - def generate_pkce_challenge() -> tuple[str, str]: """Generate PKCE challenge and verifier.""" code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8") @@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: raise ValueError(f"Invalid state parameter: {str(e)}") -def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState: - """Handle the callback from the OAuth provider.""" +def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]: + """ + Handle the callback from the OAuth provider. + + Returns: + A tuple of (callback_state, tokens) that can be used by the caller to save data. + """ # Retrieve state data from Redis (state is automatically deleted after retrieval) full_state_data = _retrieve_redis_state(state_key) @@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo full_state_data.redirect_uri, ) - # Save tokens using the service layer - mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens") - - return full_state_data + return full_state_data, tokens def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: @@ -361,11 +350,24 @@ def register_client( def auth( provider: MCPProviderEntity, - mcp_service: "MCPToolManageService", authorization_code: str | None = None, state_param: str | None = None, -) -> dict[str, str]: - """Orchestrates the full auth flow with a server using secure Redis state storage.""" +) -> AuthResult: + """ + Orchestrates the full auth flow with a server using secure Redis state storage. + + This function performs only network operations and returns actions that need + to be performed by the caller (such as saving data to database). + + Args: + provider: The MCP provider entity + authorization_code: Optional authorization code from OAuth callback + state_param: Optional state parameter from OAuth callback + + Returns: + AuthResult containing actions to be performed and response data + """ + actions: list[AuthAction] = [] server_url = provider.decrypt_server_url() server_metadata = discover_oauth_metadata(server_url) client_metadata = provider.client_metadata @@ -407,9 +409,14 @@ def auth( except RequestError as e: raise ValueError(f"Could not register OAuth client: {e}") - # Save client information using service layer - mcp_service.save_oauth_data( - provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info" + # Return action to save client information + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_CLIENT_INFO, + data={"client_information": full_information.model_dump()}, + provider_id=provider_id, + tenant_id=tenant_id, + ) ) client_information = full_information @@ -426,12 +433,20 @@ def auth( scope, ) - # Save tokens and grant type + # Return action to save tokens and grant type token_data = tokens.model_dump() token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value - mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens") - return {"result": "success"} + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=token_data, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) + + return AuthResult(actions=actions, response={"result": "success"}) except (RequestError, ValueError, KeyError) as e: # RequestError: HTTP request failed # ValueError: Invalid response data @@ -465,10 +480,17 @@ def auth( redirect_uri, ) - # Save tokens using service layer - mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens") + # Return action to save tokens + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=tokens.model_dump(), + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) - return {"result": "success"} + return AuthResult(actions=actions, response={"result": "success"}) provider_tokens = provider.retrieve_tokens() @@ -479,10 +501,17 @@ def auth( server_url, server_metadata, client_information, provider_tokens.refresh_token ) - # Save new tokens using service layer - mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens") + # Return action to save new tokens + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_TOKENS, + data=new_tokens.model_dump(), + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) - return {"result": "success"} + return AuthResult(actions=actions, response={"result": "success"}) except (RequestError, ValueError, KeyError) as e: # RequestError: HTTP request failed # ValueError: Invalid response data @@ -499,7 +528,14 @@ def auth( tenant_id, ) - # Save code verifier using service layer - mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier") + # Return action to save code verifier + actions.append( + AuthAction( + action_type=AuthActionType.SAVE_CODE_VERIFIER, + data={"code_verifier": code_verifier}, + provider_id=provider_id, + tenant_id=tenant_id, + ) + ) - return {"authorization_url": authorization_url} + return AuthResult(actions=actions, response={"authorization_url": authorization_url}) diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index 95f552f5db..942c8d3c23 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens. import logging from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Optional +from typing import Any + +from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPProviderEntity from core.mcp.error import MCPAuthError from core.mcp.mcp_client import MCPClient from core.mcp.types import CallToolResult, Tool - -if TYPE_CHECKING: - from services.tools.mcp_tools_manage_service import MCPToolManageService +from extensions.ext_database import db logger = logging.getLogger(__name__) @@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient): This class extends MCPClient and intercepts MCPAuthError exceptions to refresh authentication before retrying failed operations. + + Note: This class uses lazy session creation - database sessions are only + created when authentication retry is actually needed, not on every request. """ def __init__( @@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient): timeout: float | None = None, sse_read_timeout: float | None = None, provider_entity: MCPProviderEntity | None = None, - auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]] - | None = None, authorization_code: str | None = None, by_server_id: bool = False, - mcp_service: Optional["MCPToolManageService"] = None, ): """ Initialize the MCP client with auth retry capability. @@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient): timeout: Request timeout sse_read_timeout: SSE read timeout provider_entity: Provider entity for authentication - auth_callback: Authentication callback function authorization_code: Optional authorization code for initial auth by_server_id: Whether to look up provider by server ID - mcp_service: MCP service instance """ super().__init__(server_url, headers, timeout, sse_read_timeout) self.provider_entity = provider_entity - self.auth_callback = auth_callback self.authorization_code = authorization_code self.by_server_id = by_server_id - self.mcp_service = mcp_service self._has_retried = False def _handle_auth_error(self, error: MCPAuthError) -> None: """ Handle authentication error by refreshing tokens. + This method creates a short-lived database session only when authentication + retry is needed, minimizing database connection hold time. + Args: error: The authentication error Raises: MCPAuthError: If authentication fails or max retries reached """ - if not self.provider_entity or not self.auth_callback or not self.mcp_service: + if not self.provider_entity: raise error if self._has_retried: raise error @@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient): self._has_retried = True try: - # Perform authentication - self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code) + # Create a temporary session only for auth retry + # This session is short-lived and only exists during the auth operation - # Retrieve new tokens - self.provider_entity = self.mcp_service.get_provider_entity( - self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id - ) + from services.tools.mcp_tools_manage_service import MCPToolManageService + + with Session(db.engine) as session, session.begin(): + mcp_service = MCPToolManageService(session=session) + + # Perform authentication using the service's auth method + mcp_service.auth_with_actions(self.provider_entity, self.authorization_code) + + # Retrieve new tokens + self.provider_entity = mcp_service.get_provider_entity( + self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id + ) + + # Session is closed here, before we update headers token = self.provider_entity.retrieve_tokens() if not token: raise MCPAuthError("Authentication failed - no token received") diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py index 9e414ab2b3..08823daab1 100644 --- a/api/core/mcp/entities.py +++ b/api/core/mcp/entities.py @@ -1,8 +1,11 @@ from dataclasses import dataclass +from enum import StrEnum from typing import Any, Generic, TypeVar +from pydantic import BaseModel + from core.mcp.session.base_session import BaseSession -from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams +from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] @@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]): meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + + +class AuthActionType(StrEnum): + """Types of actions that can be performed during auth flow.""" + + SAVE_CLIENT_INFO = "save_client_info" + SAVE_TOKENS = "save_tokens" + SAVE_CODE_VERIFIER = "save_code_verifier" + START_AUTHORIZATION = "start_authorization" + SUCCESS = "success" + + +class AuthAction(BaseModel): + """Represents an action that needs to be performed as a result of auth flow.""" + + action_type: AuthActionType + data: dict[str, Any] + provider_id: str | None = None + tenant_id: str | None = None + + +class AuthResult(BaseModel): + """Result of auth function containing actions to be performed and response data.""" + + actions: list[AuthAction] + response: dict[str, str] + + +class OAuthCallbackState(BaseModel): + """State data stored in Redis during OAuth callback flow.""" + + provider_id: str + tenant_id: str + server_url: str + metadata: OAuthMetadata | None = None + client_information: OAuthClientInformation + code_verifier: str + redirect_uri: str diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 290077ecd8..a476859f29 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -3,7 +3,6 @@ import json from collections.abc import Generator from typing import Any -from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import CallToolResult, ImageContent, TextContent @@ -125,71 +124,39 @@ class MCPTool(Tool): headers = self.headers.copy() if self.headers else {} tool_parameters = self._handle_none_parameter(tool_parameters) - # Get provider entity to access tokens + from sqlalchemy.orm import Session - # Get MCP service from invoke parameters or create new one - provider_entity = None - mcp_service = None + from extensions.ext_database import db + from services.tools.mcp_tools_manage_service import MCPToolManageService - # Check if mcp_service is passed in tool_parameters - if "_mcp_service" in tool_parameters: - mcp_service = tool_parameters.pop("_mcp_service") - if mcp_service: - provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - headers = provider_entity.decrypt_headers() - # Try to get existing token and add to headers - if not headers: - tokens = provider_entity.retrieve_tokens() - if tokens and tokens.access_token: - headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + # Step 1: Load provider entity and credentials in a short-lived session + # This minimizes database connection hold time + with Session(db.engine, expire_on_commit=False) as session: + mcp_service = MCPToolManageService(session=session) + provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - # Use MCPClientWithAuthRetry to handle authentication automatically - try: - with MCPClientWithAuthRetry( - server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - provider_entity=provider_entity, - auth_callback=auth if mcp_service else None, - mcp_service=mcp_service, - ) as mcp_client: - return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - except MCPConnectionError as e: - raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e - except (ValueError, TypeError, KeyError) as e: - # Catch specific exceptions that might occur during tool invocation - raise ToolInvokeError(f"Failed to invoke tool: {e}") from e - else: - # Fallback to creating service with database session - from sqlalchemy.orm import Session + # Decrypt and prepare all credentials before closing session + server_url = provider_entity.decrypt_server_url() + headers = provider_entity.decrypt_headers() - from extensions.ext_database import db - from services.tools.mcp_tools_manage_service import MCPToolManageService + # Try to get existing token and add to headers + if not headers: + tokens = provider_entity.retrieve_tokens() + if tokens and tokens.access_token: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - with Session(db.engine, expire_on_commit=False) as session: - mcp_service = MCPToolManageService(session=session) - provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - headers = provider_entity.decrypt_headers() - # Try to get existing token and add to headers - if not headers: - tokens = provider_entity.retrieve_tokens() - if tokens and tokens.access_token: - headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - - # Use MCPClientWithAuthRetry to handle authentication automatically - try: - with MCPClientWithAuthRetry( - server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - provider_entity=provider_entity, - auth_callback=auth if mcp_service else None, - mcp_service=mcp_service, - ) as mcp_client: - return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - except MCPConnectionError as e: - raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e - except Exception as e: - raise ToolInvokeError(f"Failed to invoke tool: {e}") from e + # Step 2: Session is now closed, perform network operations without holding database connection + # MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed + try: + with MCPClientWithAuthRetry( + server_url=server_url, + headers=headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + provider_entity=provider_entity, + ) as mcp_client: + return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPConnectionError as e: + raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e + except Exception as e: + raise ToolInvokeError(f"Failed to invoke tool: {e}") from e diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 18f4c9250e..b24483b9c6 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -1,10 +1,12 @@ import hashlib import json import logging -from collections.abc import Callable from datetime import datetime +from enum import StrEnum from typing import Any +from urllib.parse import urlparse +from pydantic import BaseModel, Field from sqlalchemy import or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -12,6 +14,7 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError from core.tools.entities.api_entities import ToolProviderApiEntity @@ -28,6 +31,38 @@ EMPTY_TOOLS_JSON = "[]" EMPTY_CREDENTIALS_JSON = "{}" +class OAuthDataType(StrEnum): + """Types of OAuth data that can be saved.""" + + TOKENS = "tokens" + CLIENT_INFO = "client_info" + CODE_VERIFIER = "code_verifier" + MIXED = "mixed" + + +class ReconnectResult(BaseModel): + """Result of reconnecting to an MCP provider""" + + authed: bool = Field(description="Whether the provider is authenticated") + tools: str = Field(description="JSON string of tool list") + encrypted_credentials: str = Field(description="JSON string of encrypted credentials") + + +class ServerUrlValidationResult(BaseModel): + """Result of server URL validation check""" + + needs_validation: bool + validation_passed: bool = False + reconnect_result: ReconnectResult | None = None + encrypted_server_url: str | None = None + server_url_hash: str | None = None + + @property + def should_update_server_url(self) -> bool: + """Check if server URL should be updated based on validation result""" + return self.needs_validation and self.validation_passed and self.reconnect_result is not None + + class MCPToolManageService: """Service class for managing MCP tools and providers.""" @@ -91,6 +126,10 @@ class MCPToolManageService: headers: dict[str, str] | None = None, ) -> ToolProviderApiEntity: """Create a new MCP provider.""" + # Validate URL format + if not self._is_valid_url(server_url): + raise ValueError("Server URL is not valid.") + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() # Check for existing provider @@ -99,13 +138,12 @@ class MCPToolManageService: # Encrypt sensitive data encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None - if authentication is not None and authentication.client_id and authentication.client_secret: - # Build the full credentials structure with encrypted client_id and client_secret + encrypted_credentials = None + if authentication is not None and authentication.client_id: encrypted_credentials = self._build_and_encrypt_credentials( authentication.client_id, authentication.client_secret, tenant_id ) - else: - encrypted_credentials = None + # Create provider mcp_tool = MCPToolProvider( tenant_id=tenant_id, @@ -142,24 +180,39 @@ class MCPToolManageService: headers: dict[str, str] | None = None, configuration: MCPConfiguration, authentication: MCPAuthentication | None = None, + validation_result: ServerUrlValidationResult | None = None, ) -> None: - """Update an MCP provider.""" + """ + Update an MCP provider. + + Args: + validation_result: Pre-validation result from validate_server_url_change. + If provided and contains reconnect_result, it will be used + instead of performing network operations. + """ mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - reconnect_result = None + # Check for duplicate name (excluding current provider) + if name != mcp_provider.name: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, + MCPToolProvider.name == name, + MCPToolProvider.id != provider_id, + ) + existing_provider = self._session.scalar(stmt) + if existing_provider: + raise ValueError(f"MCP tool {name} already exists") + + # Get URL update data from validation result encrypted_server_url = None server_url_hash = None + reconnect_result = None - # Handle server URL update - if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: - encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - - if server_url_hash != mcp_provider.server_url_hash: - reconnect_result = self._reconnect_provider( - server_url=server_url, - provider=mcp_provider, - ) + if validation_result and validation_result.encrypted_server_url: + # Use all data from validation result + encrypted_server_url = validation_result.encrypted_server_url + server_url_hash = validation_result.server_url_hash + reconnect_result = validation_result.reconnect_result try: # Update basic fields @@ -169,63 +222,35 @@ class MCPToolManageService: mcp_provider.server_identifier = server_identifier # Update server URL if changed - if encrypted_server_url is not None and server_url_hash is not None: + if encrypted_server_url and server_url_hash: mcp_provider.server_url = encrypted_server_url mcp_provider.server_url_hash = server_url_hash if reconnect_result: - mcp_provider.authed = reconnect_result["authed"] - mcp_provider.tools = reconnect_result["tools"] - mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + mcp_provider.authed = reconnect_result.authed + mcp_provider.tools = reconnect_result.tools + mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials - # Update optional fields - if configuration.timeout is not None: - mcp_provider.timeout = configuration.timeout - if configuration.sse_read_timeout is not None: - mcp_provider.sse_read_timeout = configuration.sse_read_timeout + # Update optional configuration fields + self._update_optional_fields(mcp_provider, configuration) + + # Update headers if provided if headers is not None: - if headers: - # Build headers preserving unchanged masked values - final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) - encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id) - mcp_provider.encrypted_headers = encrypted_headers_dict - else: - # Clear headers if empty dict passed - mcp_provider.encrypted_headers = None + mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id) # Update credentials if provided - if authentication is not None and authentication.client_id and authentication.client_secret: - # Merge with existing credentials to handle masked values - ( - final_client_id, - final_client_secret, - ) = self._merge_credentials_with_masked( - authentication.client_id, authentication.client_secret, mcp_provider - ) + if authentication and authentication.client_id: + mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id) - # Build and encrypt new credentials - encrypted_credentials = self._build_and_encrypt_credentials( - final_client_id, final_client_secret, tenant_id - ) - mcp_provider.encrypted_credentials = encrypted_credentials - - self._session.commit() + # Flush changes to database + self._session.flush() except IntegrityError as e: - self._session.rollback() self._handle_integrity_error(e, name, server_url, server_identifier) - except (ValueError, AttributeError, TypeError) as e: - # Catch specific exceptions that might occur during update - # ValueError: invalid data provided - # AttributeError: missing required attributes - # TypeError: type conversion errors - self._session.rollback() - raise def delete_provider(self, *, tenant_id: str, provider_id: str) -> None: """Delete an MCP provider.""" mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) - self._session.commit() def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: """List all MCP providers for a tenant.""" @@ -241,8 +266,6 @@ class MCPToolManageService: def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: """List tools from remote MCP server.""" - from core.mcp.auth.auth_flow import auth - # Load provider and convert to entity db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider_entity = db_provider.to_entity() @@ -257,9 +280,7 @@ class MCPToolManageService: # Retrieve tools from remote server server_url = provider_entity.decrypt_server_url() try: - tools = self._retrieve_remote_mcp_tools( - server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c) - ) + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") @@ -305,9 +326,12 @@ class MCPToolManageService: if not authed: provider.tools = EMPTY_TOOLS_JSON - self._session.commit() + # Flush changes to database + self._session.flush() - def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None: + def save_oauth_data( + self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED + ) -> None: """ Save OAuth-related data (tokens, client info, code verifier). @@ -315,12 +339,14 @@ class MCPToolManageService: provider_id: Provider ID tenant_id: Tenant ID data: Data to save (tokens, client info, or code verifier) - data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed') + data_type: Type of OAuth data to save """ db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) # Determine if this makes the provider authenticated - authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None + authed = ( + data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None + ) self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed) @@ -330,7 +356,6 @@ class MCPToolManageService: provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON provider.updated_at = datetime.now() provider.authed = False - self._session.commit() # ========== Private Helper Methods ========== @@ -406,41 +431,123 @@ class MCPToolManageService: server_url: str, headers: dict[str, str], provider_entity: MCPProviderEntity, - auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", str | None], dict[str, str]], ): """Retrieve tools from remote MCP server.""" with MCPClientWithAuthRetry( - server_url, + server_url=server_url, headers=headers, timeout=provider_entity.timeout, sse_read_timeout=provider_entity.sse_read_timeout, provider_entity=provider_entity, - auth_callback=auth_callback, - mcp_service=self, ) as mcp_client: return mcp_client.list_tools() - def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]: - """Attempt to reconnect to MCP provider with new server URL.""" - from core.mcp.auth.auth_flow import auth + def execute_auth_actions(self, auth_result: Any) -> dict[str, str]: + """ + Execute the actions returned by the auth function. + This method processes the AuthResult and performs the necessary database operations. + + Args: + auth_result: The result from the auth function + + Returns: + The response from the auth result + """ + from core.mcp.entities import AuthAction, AuthActionType + + action: AuthAction + for action in auth_result.actions: + if action.provider_id is None or action.tenant_id is None: + continue + + if action.action_type == AuthActionType.SAVE_CLIENT_INFO: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO) + elif action.action_type == AuthActionType.SAVE_TOKENS: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS) + elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER) + + return auth_result.response + + def auth_with_actions( + self, provider_entity: MCPProviderEntity, authorization_code: str | None = None + ) -> dict[str, str]: + """ + Perform authentication and execute all resulting actions. + + This method is used by MCPClientWithAuthRetry for automatic re-authentication. + + Args: + provider_entity: The MCP provider entity + authorization_code: Optional authorization code + + Returns: + Response dictionary from auth result + """ + auth_result = auth(provider_entity, authorization_code) + return self.execute_auth_actions(auth_result) + + def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: + """Attempt to reconnect to MCP provider with new server URL.""" provider_entity = provider.to_entity() headers = provider_entity.headers try: - tools = self._retrieve_remote_mcp_tools( - server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c) + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, ) - return { - "authed": True, - "tools": json.dumps([tool.model_dump() for tool in tools]), - "encrypted_credentials": EMPTY_CREDENTIALS_JSON, - } except MCPAuthError: - return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON} + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) except MCPError as e: raise ValueError(f"Failed to re-connect MCP server: {e}") from e + def validate_server_url_change( + self, *, tenant_id: str, provider_id: str, new_server_url: str + ) -> ServerUrlValidationResult: + """ + Validate server URL change by attempting to connect to the new server. + This method should be called BEFORE update_provider to perform network operations + outside of the database transaction. + + Returns: + ServerUrlValidationResult: Validation result with connection status and tools if successful + """ + # Handle hidden/unchanged URL + if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url: + return ServerUrlValidationResult(needs_validation=False) + + # Validate URL format + if not self._is_valid_url(new_server_url): + raise ValueError("Server URL is not valid.") + + # Always encrypt and hash the URL + encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) + new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() + + # Get current provider + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + # Check if URL is actually different + if new_server_url_hash == provider.server_url_hash: + # URL hasn't changed, but still return the encrypted data + return ServerUrlValidationResult( + needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + ) + + # Perform validation by attempting to connect + reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + return ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=reconnect_result, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, + ) + def _build_tool_provider_response( self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list ) -> ToolProviderApiEntity: @@ -466,6 +573,45 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_identifier} already exists") raise + def _is_valid_url(self, url: str) -> bool: + """Validate URL format.""" + if not url: + return False + try: + parsed = urlparse(url) + return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] + except (ValueError, TypeError): + return False + + def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None: + """Update optional configuration fields using setattr for cleaner code.""" + field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout} + + for field, value in field_mapping.items(): + if value is not None: + setattr(mcp_provider, field, value) + + def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None: + """Process headers update, handling empty dict to clear headers.""" + if not headers: + return None + + # Merge with existing headers to preserve masked values + final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) + return self._prepare_encrypted_dict(final_headers, tenant_id) + + def _process_credentials( + self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str + ) -> str: + """Process credentials update, handling masked values.""" + # Merge with existing credentials + final_client_id, final_client_secret = self._merge_credentials_with_masked( + authentication.client_id, authentication.client_secret, mcp_provider + ) + + # Build and encrypt + return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id) + def _merge_headers_with_masked( self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider ) -> dict[str, str]: @@ -530,12 +676,12 @@ class MCPToolManageService: # Create a flat structure with all credential data credentials_data = { "client_id": client_id, - "encrypted_client_secret": client_secret, "client_name": CLIENT_NAME, "is_dynamic_registration": False, } - - # Only client_id and client_secret need encryption - secret_fields = ["encrypted_client_secret"] if client_secret else [] + secret_fields = [] + if client_secret is not None: + credentials_data["encrypted_client_secret"] = encrypter.encrypt_token(tenant_id, client_secret) + secret_fields = ["encrypted_client_secret"] client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id) return json.dumps({"client_information": client_info}) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 6565179f7a..3c77d0c0da 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1108,75 +1108,6 @@ class TestMCPToolManageService: assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_with_server_url_change( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test successful update of MCP provider with server URL change. - - This test verifies: - - Proper handling of server URL changes - - Correct reconnection logic - - Database state updates - - External service integration - """ - # Arrange: Create test data - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - # Create MCP provider - mcp_provider = self._create_test_mcp_provider( - db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id - ) - - from extensions.ext_database import db - - db.session.commit() - - # Mock the reconnection method - with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect: - mock_reconnect.return_value = { - "authed": True, - "tools": '[{"name": "test_tool"}]', - "encrypted_credentials": "{}", - } - - # Act: Execute the method under test - from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - service.update_provider( - tenant_id=tenant.id, - provider_id=mcp_provider.id, - name="Updated MCP Provider", - server_url="https://new-example.com/mcp", - icon="🚀", - icon_type="emoji", - icon_background="#4ECDC4", - server_identifier="updated_identifier_123", - configuration=MCPConfiguration( - timeout=45.0, - sse_read_timeout=400.0, - ), - ) - - # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) - assert mcp_provider.name == "Updated MCP Provider" - assert mcp_provider.server_identifier == "updated_identifier_123" - assert mcp_provider.timeout == 45.0 - assert mcp_provider.sse_read_timeout == 400.0 - assert mcp_provider.updated_at is not None - - # Verify reconnection was called - mock_reconnect.assert_called_once_with( - server_url="https://new-example.com/mcp", - provider=mcp_provider, - ) - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): """ Test error handling when updating MCP provider with duplicate name. @@ -1387,14 +1318,14 @@ class TestMCPToolManageService: # Assert: Verify the expected outcomes assert result is not None - assert result["authed"] is True - assert result["tools"] is not None - assert result["encrypted_credentials"] == "{}" + assert result.authed is True + assert result.tools is not None + assert result.encrypted_credentials == "{}" # Verify tools were properly serialized import json - tools_data = json.loads(result["tools"]) + tools_data = json.loads(result.tools) assert len(tools_data) == 2 assert tools_data[0]["name"] == "test_tool_1" assert tools_data[1]["name"] == "test_tool_2" @@ -1441,9 +1372,9 @@ class TestMCPToolManageService: # Assert: Verify the expected outcomes assert result is not None - assert result["authed"] is False - assert result["tools"] == "[]" - assert result["encrypted_credentials"] == "{}" + assert result.authed is False + assert result.tools == "[]" + assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 26b5d1f7ce..12a9f11205 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -21,6 +21,7 @@ from core.mcp.auth.auth_flow import ( register_client, start_authorization, ) +from core.mcp.entities import AuthActionType, AuthResult from core.mcp.types import ( OAuthClientInformation, OAuthClientInformationFull, @@ -527,9 +528,10 @@ class TestCallbackHandling: # Setup service mock_service = Mock() - result = handle_callback("state-key", "auth-code", mock_service) + state_result, tokens_result = handle_callback("state-key", "auth-code") - assert result == state_data + assert state_result == state_data + assert tokens_result == tokens # Verify calls mock_retrieve_state.assert_called_once_with("state-key") @@ -541,9 +543,8 @@ class TestCallbackHandling: "test-verifier", "https://redirect.example.com", ) - mock_service.save_oauth_data.assert_called_once_with( - "test-provider", "test-tenant", tokens.model_dump(), "tokens" - ) + # Note: handle_callback no longer saves tokens directly, it just returns them + # The caller (e.g., controller) is responsible for saving via execute_auth_actions class TestAuthOrchestration: @@ -589,21 +590,28 @@ class TestAuthOrchestration: ) mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier") - result = auth(mock_provider, mock_service) + result = auth(mock_provider) - assert result == {"authorization_url": "https://auth.example.com/authorize?..."} + # auth() now returns AuthResult + assert isinstance(result, AuthResult) + assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."} + + # Verify that the result contains the correct actions + assert len(result.actions) == 2 + # Check for SAVE_CLIENT_INFO action + client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO) + assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()} + assert client_info_action.provider_id == "provider-id" + assert client_info_action.tenant_id == "tenant-id" + + # Check for SAVE_CODE_VERIFIER action + verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER) + assert verifier_action.data == {"code_verifier": "code-verifier"} + assert verifier_action.provider_id == "provider-id" + assert verifier_action.tenant_id == "tenant-id" # Verify calls mock_register.assert_called_once() - mock_service.save_oauth_data.assert_any_call( - "provider-id", - "tenant-id", - {"client_information": mock_register.return_value.model_dump()}, - "client_info", - ) - mock_service.save_oauth_data.assert_any_call( - "provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier" - ) @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") @patch("core.mcp.auth.auth_flow._retrieve_redis_state") @@ -637,12 +645,18 @@ class TestAuthOrchestration: tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600) mock_exchange.return_value = tokens - result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key") + result = auth(mock_provider, authorization_code="auth-code", state_param="state-key") - assert result == {"result": "success"} + # auth() now returns AuthResult, not a dict + assert isinstance(result, AuthResult) + assert result.response == {"result": "success"} - # Verify token save - mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens") + # Verify that the result contains the correct action + assert len(result.actions) == 1 + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data == tokens.model_dump() + assert result.actions[0].provider_id == "provider-id" + assert result.actions[0].tenant_id == "tenant-id" @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service): @@ -658,7 +672,7 @@ class TestAuthOrchestration: mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") with pytest.raises(ValueError) as exc_info: - auth(mock_provider, mock_service, authorization_code="auth-code") + auth(mock_provider, authorization_code="auth-code") assert "State parameter is required" in str(exc_info.value) @@ -691,15 +705,21 @@ class TestAuthOrchestration: grant_types_supported=["authorization_code"], ) - result = auth(mock_provider, mock_service) + result = auth(mock_provider) - assert result == {"result": "success"} + # auth() now returns AuthResult + assert isinstance(result, AuthResult) + assert result.response == {"result": "success"} + + # Verify that the result contains the correct action + assert len(result.actions) == 1 + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data == new_tokens.model_dump() + assert result.actions[0].provider_id == "provider-id" + assert result.actions[0].tenant_id == "tenant-id" # Verify refresh was called mock_refresh.assert_called_once() - mock_service.save_oauth_data.assert_called_with( - "provider-id", "tenant-id", new_tokens.model_dump(), "tokens" - ) @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service): @@ -715,6 +735,6 @@ class TestAuthOrchestration: mock_provider.retrieve_client_information.return_value = None with pytest.raises(ValueError) as exc_info: - auth(mock_provider, mock_service, authorization_code="auth-code") + auth(mock_provider, authorization_code="auth-code") assert "Existing OAuth client information is required" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/mcp/test_auth_client.py b/api/tests/unit_tests/core/mcp/test_auth_client.py deleted file mode 100644 index 7b06c9df4d..0000000000 --- a/api/tests/unit_tests/core/mcp/test_auth_client.py +++ /dev/null @@ -1,420 +0,0 @@ -"""Unit tests for MCP auth client with retry logic.""" - -from types import TracebackType -from unittest.mock import Mock, patch - -import pytest - -from core.entities.mcp_provider import MCPProviderEntity -from core.mcp.auth_client import MCPClientWithAuthRetry -from core.mcp.error import MCPAuthError -from core.mcp.mcp_client import MCPClient -from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations - - -class TestMCPClientWithAuthRetry: - """Test suite for MCPClientWithAuthRetry.""" - - @pytest.fixture - def mock_provider_entity(self): - """Create a mock provider entity.""" - provider = Mock(spec=MCPProviderEntity) - provider.id = "test-provider-id" - provider.tenant_id = "test-tenant-id" - provider.retrieve_tokens.return_value = Mock( - access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - return provider - - @pytest.fixture - def mock_mcp_service(self): - """Create a mock MCP service.""" - service = Mock() - service.get_provider_entity.return_value = Mock( - retrieve_tokens=lambda: Mock( - access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - ) - return service - - @pytest.fixture - def auth_callback(self): - """Create a mock auth callback.""" - return Mock() - - def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test client initialization.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - headers={"Authorization": "Bearer test"}, - timeout=30.0, - sse_read_timeout=60.0, - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - authorization_code="test-auth-code", - by_server_id=True, - mcp_service=mock_mcp_service, - ) - - assert client.server_url == "http://test.example.com" - assert client.headers == {"Authorization": "Bearer test"} - assert client.timeout == 30.0 - assert client.sse_read_timeout == 60.0 - assert client.provider_entity == mock_provider_entity - assert client.auth_callback == auth_callback - assert client.authorization_code == "test-auth-code" - assert client.by_server_id is True - assert client.mcp_service == mock_mcp_service - assert client._has_retried is False - # In inheritance design, we don't have _client attribute - assert hasattr(client, "_session") # Inherited from MCPClient - - def test_inheritance_structure(self): - """Test that MCPClientWithAuthRetry properly inherits from MCPClient.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - headers={"Authorization": "Bearer test"}, - ) - - # Verify inheritance - assert isinstance(client, MCPClient) - - # Verify inherited attributes are accessible - assert hasattr(client, "server_url") - assert hasattr(client, "headers") - assert hasattr(client, "_session") - assert hasattr(client, "_exit_stack") - assert hasattr(client, "_initialized") - - def test_handle_auth_error_no_retry_components(self): - """Test auth error handling when retry components are missing.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - error = MCPAuthError("Auth failed") - - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert exc_info.value == error - - def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test auth error handling when already retried.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - client._has_retried = True - error = MCPAuthError("Auth failed") - - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert exc_info.value == error - auth_callback.assert_not_called() - - def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test successful auth refresh on error.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - authorization_code="test-code", - by_server_id=True, - mcp_service=mock_mcp_service, - ) - - # Configure mocks - new_provider = Mock(spec=MCPProviderEntity) - new_provider.id = "test-provider-id" - new_provider.tenant_id = "test-tenant-id" - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - error = MCPAuthError("Auth failed") - client._handle_auth_error(error) - - # Verify auth flow - auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code") - mock_mcp_service.get_provider_entity.assert_called_once_with( - "test-provider-id", "test-tenant-id", by_server_id=True - ) - assert client.headers["Authorization"] == "Bearer new-token" - assert client.authorization_code is None # Should be cleared after use - assert client._has_retried is True - - def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test auth refresh failure.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - auth_callback.side_effect = Exception("Auth callback failed") - - error = MCPAuthError("Original auth failed") - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert "Authentication retry failed" in str(exc_info.value) - - def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test auth refresh when no token is received.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure mock to return no token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = None - mock_mcp_service.get_provider_entity.return_value = new_provider - - error = MCPAuthError("Auth failed") - with pytest.raises(MCPAuthError) as exc_info: - client._handle_auth_error(error) - - assert "no token received" in str(exc_info.value) - - def test_execute_with_retry_success(self): - """Test successful execution without retry.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - mock_func = Mock(return_value="success") - result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") - - assert result == "success" - mock_func.assert_called_once_with("arg1", kwarg1="value1") - assert client._has_retried is False - - def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test execution with auth error followed by successful retry.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - # Mock function that fails first, then succeeds - mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"]) - - # Mock the exit stack and session cleanup - with ( - patch.object(client, "_exit_stack") as mock_exit_stack, - patch.object(client, "_session") as mock_session, - patch.object(client, "_initialize") as mock_initialize, - ): - client._initialized = True - result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") - - assert result == "success" - assert mock_func.call_count == 2 - mock_func.assert_called_with("arg1", kwarg1="value1") - auth_callback.assert_called_once() - mock_exit_stack.close.assert_called_once() - mock_initialize.assert_called_once() - assert client._has_retried is False # Reset after completion - - def test_execute_with_retry_non_auth_error(self): - """Test execution with non-auth error (no retry).""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - mock_func = Mock(side_effect=ValueError("Some other error")) - - with pytest.raises(ValueError) as exc_info: - client._execute_with_retry(mock_func) - - assert str(exc_info.value) == "Some other error" - mock_func.assert_called_once() - - def test_context_manager_enter(self): - """Test context manager enter.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with patch.object(client, "_initialize") as mock_initialize: - result = client.__enter__() - - assert result == client - assert client._initialized is True - mock_initialize.assert_called_once() - - def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test context manager enter with auth error and retry.""" - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Mock parent class __enter__ to raise auth error first, then succeed - with patch.object(MCPClient, "__enter__") as mock_parent_enter: - mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client] - - result = client.__enter__() - - assert result == client - assert mock_parent_enter.call_count == 2 - auth_callback.assert_called_once() - - def test_context_manager_exit(self): - """Test context manager exit.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with patch.object(client, "cleanup") as mock_cleanup: - exc_type: type[BaseException] | None = None - exc_val: BaseException | None = None - exc_tb: TracebackType | None = None - client.__exit__(exc_type, exc_val, exc_tb) - - mock_cleanup.assert_called_once() - - def test_list_tools_not_initialized(self): - """Test list_tools when client not initialized.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with pytest.raises(ValueError) as exc_info: - client.list_tools() - - assert "Session not initialized" in str(exc_info.value) - - def test_list_tools_success(self): - """Test successful list_tools call.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - expected_tools = [ - Tool( - name="test-tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - annotations=ToolAnnotations(title="Test Tool"), - ) - ] - - # Mock the parent class list_tools method - with patch.object(MCPClient, "list_tools", return_value=expected_tools): - result = client.list_tools() - assert result == expected_tools - - def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test list_tools with auth retry.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})] - - # Mock parent class list_tools to raise auth error first, then succeed - with patch.object(MCPClient, "list_tools") as mock_list_tools: - mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools] - - result = client.list_tools() - - assert result == expected_tools - assert mock_list_tools.call_count == 2 - auth_callback.assert_called_once() - - def test_invoke_tool_not_initialized(self): - """Test invoke_tool when client not initialized.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - with pytest.raises(ValueError) as exc_info: - client.invoke_tool("test-tool", {"arg": "value"}) - - assert "Session not initialized" in str(exc_info.value) - - def test_invoke_tool_success(self): - """Test successful invoke_tool call.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - expected_result = CallToolResult( - content=[TextContent(type="text", text="Tool executed successfully")], isError=False - ) - - # Mock the parent class invoke_tool method - with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke: - result = client.invoke_tool("test-tool", {"arg": "value"}) - - assert result == expected_result - mock_invoke.assert_called_once_with("test-tool", {"arg": "value"}) - - def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback): - """Test invoke_tool with auth retry.""" - client = MCPClientWithAuthRetry( - server_url="http://test.example.com", - provider_entity=mock_provider_entity, - auth_callback=auth_callback, - mcp_service=mock_mcp_service, - ) - - # Configure new provider with token - new_provider = Mock(spec=MCPProviderEntity) - new_provider.retrieve_tokens.return_value = Mock( - access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None - ) - mock_mcp_service.get_provider_entity.return_value = new_provider - - expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False) - - # Mock parent class invoke_tool to raise auth error first, then succeed - with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool: - mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result] - - result = client.invoke_tool("test-tool", {"arg": "value"}) - - assert result == expected_result - assert mock_invoke_tool.call_count == 2 - mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"}) - auth_callback.assert_called_once() - - def test_cleanup(self): - """Test cleanup method.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - # Mock the parent class cleanup method - with patch.object(MCPClient, "cleanup") as mock_cleanup: - client.cleanup() - mock_cleanup.assert_called_once() - - def test_cleanup_no_client(self): - """Test cleanup when no client exists.""" - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - - # Should not raise - client.cleanup() - - # Since MCPClientWithAuthRetry inherits from MCPClient, - # it doesn't have a _client attribute. The test should just - # verify that cleanup can be called without error. - assert not hasattr(client, "_client")