refactor(mcp): clean the auth code

This commit is contained in:
Novice 2025-10-23 17:00:02 +08:00
parent 8cf4a0d3ad
commit ffd3a461f6
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
9 changed files with 521 additions and 781 deletions

View File

@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
args = parser.parse_args()
user, tenant_id = current_account_with_tenant()
# Validate server URL
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
headers=args["headers"],
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
try:
return auth(provider_entity, service, args.get("authorization_code"))
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
auth_result = auth(provider_entity, args.get("authorization_code"))
with session.begin():
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
# Save tokens using the service layer
mcp_service.save_oauth_data(
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -4,14 +4,14 @@ import json
import os
import secrets
import urllib.parse
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
@ -23,23 +23,10 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
"""
Handle the callback from the OAuth provider.
Returns:
A tuple of (callback_state, tokens) that can be used by the caller to save data.
"""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
full_state_data.redirect_uri,
)
# Save tokens using the service layer
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
return full_state_data
return full_state_data, tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
@ -361,11 +350,24 @@ def register_client(
def auth(
provider: MCPProviderEntity,
mcp_service: "MCPToolManageService",
authorization_code: str | None = None,
state_param: str | None = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
This function performs only network operations and returns actions that need
to be performed by the caller (such as saving data to database).
Args:
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
@ -407,9 +409,14 @@ def auth(
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
# Save client information using service layer
mcp_service.save_oauth_data(
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
# Return action to save client information
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CLIENT_INFO,
data={"client_information": full_information.model_dump()},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
client_information = full_information
@ -426,12 +433,20 @@ def auth(
scope,
)
# Save tokens and grant type
# Return action to save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
return {"result": "success"}
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=token_data,
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
@ -465,10 +480,17 @@ def auth(
redirect_uri,
)
# Save tokens using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
# Return action to save tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return {"result": "success"}
return AuthResult(actions=actions, response={"result": "success"})
provider_tokens = provider.retrieve_tokens()
@ -479,10 +501,17 @@ def auth(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Save new tokens using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
# Return action to save new tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=new_tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return {"result": "success"}
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
@ -499,7 +528,14 @@ def auth(
tenant_id,
)
# Save code verifier using service layer
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
# Return action to save code verifier
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CODE_VERIFIER,
data={"code_verifier": code_verifier},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return {"authorization_url": authorization_url}
return AuthResult(actions=actions, response={"authorization_url": authorization_url})

View File

@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens.
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
from typing import Any
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient):
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
Note: This class uses lazy session creation - database sessions are only
created when authentication retry is actually needed, not on every request.
"""
def __init__(
@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient):
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
| None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
mcp_service: Optional["MCPToolManageService"] = None,
):
"""
Initialize the MCP client with auth retry capability.
@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient):
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
auth_callback: Authentication callback function
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
mcp_service: MCP service instance
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.auth_callback = auth_callback
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self.mcp_service = mcp_service
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
This method creates a short-lived database session only when authentication
retry is needed, minimizing database connection hold time.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
if not self.provider_entity:
raise error
if self._has_retried:
raise error
@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient):
self._has_retried = True
try:
# Perform authentication
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
# Create a temporary session only for auth retry
# This session is short-lived and only exists during the auth operation
# Retrieve new tokens
self.provider_entity = self.mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
# Session is closed here, before we update headers
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")

View File

@ -1,8 +1,11 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
class AuthActionType(StrEnum):
"""Types of actions that can be performed during auth flow."""
SAVE_CLIENT_INFO = "save_client_info"
SAVE_TOKENS = "save_tokens"
SAVE_CODE_VERIFIER = "save_code_verifier"
START_AUTHORIZATION = "start_authorization"
SUCCESS = "success"
class AuthAction(BaseModel):
"""Represents an action that needs to be performed as a result of auth flow."""
action_type: AuthActionType
data: dict[str, Any]
provider_id: str | None = None
tenant_id: str | None = None
class AuthResult(BaseModel):
"""Result of auth function containing actions to be performed and response data."""
actions: list[AuthAction]
response: dict[str, str]
class OAuthCallbackState(BaseModel):
"""State data stored in Redis during OAuth callback flow."""
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str

View File

@ -3,7 +3,6 @@ import json
from collections.abc import Generator
from typing import Any
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
@ -125,71 +124,39 @@ class MCPTool(Tool):
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
# Get provider entity to access tokens
from sqlalchemy.orm import Session
# Get MCP service from invoke parameters or create new one
provider_entity = None
mcp_service = None
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Check if mcp_service is passed in tool_parameters
if "_mcp_service" in tool_parameters:
mcp_service = tool_parameters.pop("_mcp_service")
if mcp_service:
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Step 1: Load provider entity and credentials in a short-lived session
# This minimizes database connection hold time
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except (ValueError, TypeError, KeyError) as e:
# Catch specific exceptions that might occur during tool invocation
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
else:
# Fallback to creating service with database session
from sqlalchemy.orm import Session
# Decrypt and prepare all credentials before closing session
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@ -1,10 +1,12 @@
import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime
from enum import StrEnum
from typing import Any
from urllib.parse import urlparse
from pydantic import BaseModel, Field
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@ -12,6 +14,7 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.tools.entities.api_entities import ToolProviderApiEntity
@ -28,6 +31,38 @@ EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
class OAuthDataType(StrEnum):
"""Types of OAuth data that can be saved."""
TOKENS = "tokens"
CLIENT_INFO = "client_info"
CODE_VERIFIER = "code_verifier"
MIXED = "mixed"
class ReconnectResult(BaseModel):
"""Result of reconnecting to an MCP provider"""
authed: bool = Field(description="Whether the provider is authenticated")
tools: str = Field(description="JSON string of tool list")
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
class ServerUrlValidationResult(BaseModel):
"""Result of server URL validation check"""
needs_validation: bool
validation_passed: bool = False
reconnect_result: ReconnectResult | None = None
encrypted_server_url: str | None = None
server_url_hash: str | None = None
@property
def should_update_server_url(self) -> bool:
"""Check if server URL should be updated based on validation result"""
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class MCPToolManageService:
"""Service class for managing MCP tools and providers."""
@ -91,6 +126,10 @@ class MCPToolManageService:
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
# Validate URL format
if not self._is_valid_url(server_url):
raise ValueError("Server URL is not valid.")
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
# Check for existing provider
@ -99,13 +138,12 @@ class MCPToolManageService:
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
if authentication is not None and authentication.client_id and authentication.client_secret:
# Build the full credentials structure with encrypted client_id and client_secret
encrypted_credentials = None
if authentication is not None and authentication.client_id:
encrypted_credentials = self._build_and_encrypt_credentials(
authentication.client_id, authentication.client_secret, tenant_id
)
else:
encrypted_credentials = None
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
@ -142,24 +180,39 @@ class MCPToolManageService:
headers: dict[str, str] | None = None,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
validation_result: ServerUrlValidationResult | None = None,
) -> None:
"""Update an MCP provider."""
"""
Update an MCP provider.
Args:
validation_result: Pre-validation result from validate_server_url_change.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
reconnect_result = None
# Check for duplicate name (excluding current provider)
if name != mcp_provider.name:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
MCPToolProvider.name == name,
MCPToolProvider.id != provider_id,
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
raise ValueError(f"MCP tool {name} already exists")
# Get URL update data from validation result
encrypted_server_url = None
server_url_hash = None
reconnect_result = None
# Handle server URL update
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
reconnect_result = self._reconnect_provider(
server_url=server_url,
provider=mcp_provider,
)
if validation_result and validation_result.encrypted_server_url:
# Use all data from validation result
encrypted_server_url = validation_result.encrypted_server_url
server_url_hash = validation_result.server_url_hash
reconnect_result = validation_result.reconnect_result
try:
# Update basic fields
@ -169,63 +222,35 @@ class MCPToolManageService:
mcp_provider.server_identifier = server_identifier
# Update server URL if changed
if encrypted_server_url is not None and server_url_hash is not None:
if encrypted_server_url and server_url_hash:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
if reconnect_result:
mcp_provider.authed = reconnect_result["authed"]
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
mcp_provider.authed = reconnect_result.authed
mcp_provider.tools = reconnect_result.tools
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
# Update optional fields
if configuration.timeout is not None:
mcp_provider.timeout = configuration.timeout
if configuration.sse_read_timeout is not None:
mcp_provider.sse_read_timeout = configuration.sse_read_timeout
# Update optional configuration fields
self._update_optional_fields(mcp_provider, configuration)
# Update headers if provided
if headers is not None:
if headers:
# Build headers preserving unchanged masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
mcp_provider.encrypted_headers = encrypted_headers_dict
else:
# Clear headers if empty dict passed
mcp_provider.encrypted_headers = None
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
# Update credentials if provided
if authentication is not None and authentication.client_id and authentication.client_secret:
# Merge with existing credentials to handle masked values
(
final_client_id,
final_client_secret,
) = self._merge_credentials_with_masked(
authentication.client_id, authentication.client_secret, mcp_provider
)
if authentication and authentication.client_id:
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
# Build and encrypt new credentials
encrypted_credentials = self._build_and_encrypt_credentials(
final_client_id, final_client_secret, tenant_id
)
mcp_provider.encrypted_credentials = encrypted_credentials
self._session.commit()
# Flush changes to database
self._session.flush()
except IntegrityError as e:
self._session.rollback()
self._handle_integrity_error(e, name, server_url, server_identifier)
except (ValueError, AttributeError, TypeError) as e:
# Catch specific exceptions that might occur during update
# ValueError: invalid data provided
# AttributeError: missing required attributes
# TypeError: type conversion errors
self._session.rollback()
raise
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider."""
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
self._session.commit()
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant."""
@ -241,8 +266,6 @@ class MCPToolManageService:
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server."""
from core.mcp.auth.auth_flow import auth
# Load provider and convert to entity
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = db_provider.to_entity()
@ -257,9 +280,7 @@ class MCPToolManageService:
# Retrieve tools from remote server
server_url = provider_entity.decrypt_server_url()
try:
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
)
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
@ -305,9 +326,12 @@ class MCPToolManageService:
if not authed:
provider.tools = EMPTY_TOOLS_JSON
self._session.commit()
# Flush changes to database
self._session.flush()
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
def save_oauth_data(
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
) -> None:
"""
Save OAuth-related data (tokens, client info, code verifier).
@ -315,12 +339,14 @@ class MCPToolManageService:
provider_id: Provider ID
tenant_id: Tenant ID
data: Data to save (tokens, client info, or code verifier)
data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed')
data_type: Type of OAuth data to save
"""
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Determine if this makes the provider authenticated
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
authed = (
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
)
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
@ -330,7 +356,6 @@ class MCPToolManageService:
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
# ========== Private Helper Methods ==========
@ -406,41 +431,123 @@ class MCPToolManageService:
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", str | None], dict[str, str]],
):
"""Retrieve tools from remote MCP server."""
with MCPClientWithAuthRetry(
server_url,
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth_callback,
mcp_service=self,
) as mcp_client:
return mcp_client.list_tools()
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
"""Attempt to reconnect to MCP provider with new server URL."""
from core.mcp.auth.auth_flow import auth
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
"""
Execute the actions returned by the auth function.
This method processes the AuthResult and performs the necessary database operations.
Args:
auth_result: The result from the auth function
Returns:
The response from the auth result
"""
from core.mcp.entities import AuthAction, AuthActionType
action: AuthAction
for action in auth_result.actions:
if action.provider_id is None or action.tenant_id is None:
continue
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
elif action.action_type == AuthActionType.SAVE_TOKENS:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
return auth_result.response
def auth_with_actions(
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
) -> dict[str, str]:
"""
Perform authentication and execute all resulting actions.
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
Args:
provider_entity: The MCP provider entity
authorization_code: Optional authorization code
Returns:
Response dictionary from auth result
"""
auth_result = auth(provider_entity, authorization_code)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
provider_entity = provider.to_entity()
headers = provider_entity.headers
try:
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
}
except MCPAuthError:
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
"""
# Handle hidden/unchanged URL
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
if not self._is_valid_url(new_server_url):
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
)
# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
reconnect_result=reconnect_result,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
@ -466,6 +573,45 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format."""
if not url:
return False
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except (ValueError, TypeError):
return False
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
"""Update optional configuration fields using setattr for cleaner code."""
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
for field, value in field_mapping.items():
if value is not None:
setattr(mcp_provider, field, value)
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
"""Process headers update, handling empty dict to clear headers."""
if not headers:
return None
# Merge with existing headers to preserve masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
return self._prepare_encrypted_dict(final_headers, tenant_id)
def _process_credentials(
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
) -> str:
"""Process credentials update, handling masked values."""
# Merge with existing credentials
final_client_id, final_client_secret = self._merge_credentials_with_masked(
authentication.client_id, authentication.client_secret, mcp_provider
)
# Build and encrypt
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
def _merge_headers_with_masked(
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
) -> dict[str, str]:
@ -530,12 +676,12 @@ class MCPToolManageService:
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"encrypted_client_secret": client_secret,
"client_name": CLIENT_NAME,
"is_dynamic_registration": False,
}
# Only client_id and client_secret need encryption
secret_fields = ["encrypted_client_secret"] if client_secret else []
secret_fields = []
if client_secret is not None:
credentials_data["encrypted_client_secret"] = encrypter.encrypt_token(tenant_id, client_secret)
secret_fields = ["encrypted_client_secret"]
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
return json.dumps({"client_information": client_info})

View File

@ -1108,75 +1108,6 @@ class TestMCPToolManageService:
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
def test_update_mcp_provider_with_server_url_change(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful update of MCP provider with server URL change.
This test verifies:
- Proper handling of server URL changes
- Correct reconnection logic
- Database state updates
- External service integration
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create MCP provider
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
from extensions.ext_database import db
db.session.commit()
# Mock the reconnection method
with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect:
mock_reconnect.return_value = {
"authed": True,
"tools": '[{"name": "test_tool"}]',
"encrypted_credentials": "{}",
}
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
server_url="https://new-example.com/mcp",
icon="🚀",
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
assert mcp_provider.name == "Updated MCP Provider"
assert mcp_provider.server_identifier == "updated_identifier_123"
assert mcp_provider.timeout == 45.0
assert mcp_provider.sse_read_timeout == 400.0
assert mcp_provider.updated_at is not None
# Verify reconnection was called
mock_reconnect.assert_called_once_with(
server_url="https://new-example.com/mcp",
provider=mcp_provider,
)
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test error handling when updating MCP provider with duplicate name.
@ -1387,14 +1318,14 @@ class TestMCPToolManageService:
# Assert: Verify the expected outcomes
assert result is not None
assert result["authed"] is True
assert result["tools"] is not None
assert result["encrypted_credentials"] == "{}"
assert result.authed is True
assert result.tools is not None
assert result.encrypted_credentials == "{}"
# Verify tools were properly serialized
import json
tools_data = json.loads(result["tools"])
tools_data = json.loads(result.tools)
assert len(tools_data) == 2
assert tools_data[0]["name"] == "test_tool_1"
assert tools_data[1]["name"] == "test_tool_2"
@ -1441,9 +1372,9 @@ class TestMCPToolManageService:
# Assert: Verify the expected outcomes
assert result is not None
assert result["authed"] is False
assert result["tools"] == "[]"
assert result["encrypted_credentials"] == "{}"
assert result.authed is False
assert result.tools == "[]"
assert result.encrypted_credentials == "{}"
def test_re_connect_mcp_provider_connection_error(
self, db_session_with_containers, mock_external_service_dependencies

View File

@ -21,6 +21,7 @@ from core.mcp.auth.auth_flow import (
register_client,
start_authorization,
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
@ -527,9 +528,10 @@ class TestCallbackHandling:
# Setup service
mock_service = Mock()
result = handle_callback("state-key", "auth-code", mock_service)
state_result, tokens_result = handle_callback("state-key", "auth-code")
assert result == state_data
assert state_result == state_data
assert tokens_result == tokens
# Verify calls
mock_retrieve_state.assert_called_once_with("state-key")
@ -541,9 +543,8 @@ class TestCallbackHandling:
"test-verifier",
"https://redirect.example.com",
)
mock_service.save_oauth_data.assert_called_once_with(
"test-provider", "test-tenant", tokens.model_dump(), "tokens"
)
# Note: handle_callback no longer saves tokens directly, it just returns them
# The caller (e.g., controller) is responsible for saving via execute_auth_actions
class TestAuthOrchestration:
@ -589,21 +590,28 @@ class TestAuthOrchestration:
)
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
result = auth(mock_provider, mock_service)
result = auth(mock_provider)
assert result == {"authorization_url": "https://auth.example.com/authorize?..."}
# auth() now returns AuthResult
assert isinstance(result, AuthResult)
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
# Verify that the result contains the correct actions
assert len(result.actions) == 2
# Check for SAVE_CLIENT_INFO action
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
assert client_info_action.provider_id == "provider-id"
assert client_info_action.tenant_id == "tenant-id"
# Check for SAVE_CODE_VERIFIER action
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
assert verifier_action.data == {"code_verifier": "code-verifier"}
assert verifier_action.provider_id == "provider-id"
assert verifier_action.tenant_id == "tenant-id"
# Verify calls
mock_register.assert_called_once()
mock_service.save_oauth_data.assert_any_call(
"provider-id",
"tenant-id",
{"client_information": mock_register.return_value.model_dump()},
"client_info",
)
mock_service.save_oauth_data.assert_any_call(
"provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@ -637,12 +645,18 @@ class TestAuthOrchestration:
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
mock_exchange.return_value = tokens
result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key")
result = auth(mock_provider, authorization_code="auth-code", state_param="state-key")
assert result == {"result": "success"}
# auth() now returns AuthResult, not a dict
assert isinstance(result, AuthResult)
assert result.response == {"result": "success"}
# Verify token save
mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens")
# Verify that the result contains the correct action
assert len(result.actions) == 1
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
assert result.actions[0].data == tokens.model_dump()
assert result.actions[0].provider_id == "provider-id"
assert result.actions[0].tenant_id == "tenant-id"
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
@ -658,7 +672,7 @@ class TestAuthOrchestration:
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
auth(mock_provider, authorization_code="auth-code")
assert "State parameter is required" in str(exc_info.value)
@ -691,15 +705,21 @@ class TestAuthOrchestration:
grant_types_supported=["authorization_code"],
)
result = auth(mock_provider, mock_service)
result = auth(mock_provider)
assert result == {"result": "success"}
# auth() now returns AuthResult
assert isinstance(result, AuthResult)
assert result.response == {"result": "success"}
# Verify that the result contains the correct action
assert len(result.actions) == 1
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
assert result.actions[0].data == new_tokens.model_dump()
assert result.actions[0].provider_id == "provider-id"
assert result.actions[0].tenant_id == "tenant-id"
# Verify refresh was called
mock_refresh.assert_called_once()
mock_service.save_oauth_data.assert_called_with(
"provider-id", "tenant-id", new_tokens.model_dump(), "tokens"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
@ -715,6 +735,6 @@ class TestAuthOrchestration:
mock_provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
auth(mock_provider, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)

View File

@ -1,420 +0,0 @@
"""Unit tests for MCP auth client with retry logic."""
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
class TestMCPClientWithAuthRetry:
"""Test suite for MCPClientWithAuthRetry."""
@pytest.fixture
def mock_provider_entity(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "test-provider-id"
provider.tenant_id = "test-tenant-id"
provider.retrieve_tokens.return_value = Mock(
access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
return provider
@pytest.fixture
def mock_mcp_service(self):
"""Create a mock MCP service."""
service = Mock()
service.get_provider_entity.return_value = Mock(
retrieve_tokens=lambda: Mock(
access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
)
return service
@pytest.fixture
def auth_callback(self):
"""Create a mock auth callback."""
return Mock()
def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test client initialization."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-auth-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
assert client.server_url == "http://test.example.com"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client.provider_entity == mock_provider_entity
assert client.auth_callback == auth_callback
assert client.authorization_code == "test-auth-code"
assert client.by_server_id is True
assert client.mcp_service == mock_mcp_service
assert client._has_retried is False
# In inheritance design, we don't have _client attribute
assert hasattr(client, "_session") # Inherited from MCPClient
def test_inheritance_structure(self):
"""Test that MCPClientWithAuthRetry properly inherits from MCPClient."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
)
# Verify inheritance
assert isinstance(client, MCPClient)
# Verify inherited attributes are accessible
assert hasattr(client, "server_url")
assert hasattr(client, "headers")
assert hasattr(client, "_session")
assert hasattr(client, "_exit_stack")
assert hasattr(client, "_initialized")
def test_handle_auth_error_no_retry_components(self):
"""Test auth error handling when retry components are missing."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth error handling when already retried."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
client._has_retried = True
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
auth_callback.assert_not_called()
def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test successful auth refresh on error."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
# Configure mocks
new_provider = Mock(spec=MCPProviderEntity)
new_provider.id = "test-provider-id"
new_provider.tenant_id = "test-tenant-id"
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
client._handle_auth_error(error)
# Verify auth flow
auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code")
mock_mcp_service.get_provider_entity.assert_called_once_with(
"test-provider-id", "test-tenant-id", by_server_id=True
)
assert client.headers["Authorization"] == "Bearer new-token"
assert client.authorization_code is None # Should be cleared after use
assert client._has_retried is True
def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh failure."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
auth_callback.side_effect = Exception("Auth callback failed")
error = MCPAuthError("Original auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "Authentication retry failed" in str(exc_info.value)
def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh when no token is received."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock to return no token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = None
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "no token received" in str(exc_info.value)
def test_execute_with_retry_success(self):
"""Test successful execution without retry."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(return_value="success")
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
mock_func.assert_called_once_with("arg1", kwarg1="value1")
assert client._has_retried is False
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test execution with auth error followed by successful retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
# Mock function that fails first, then succeeds
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
# Mock the exit stack and session cleanup
with (
patch.object(client, "_exit_stack") as mock_exit_stack,
patch.object(client, "_session") as mock_session,
patch.object(client, "_initialize") as mock_initialize,
):
client._initialized = True
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
assert mock_func.call_count == 2
mock_func.assert_called_with("arg1", kwarg1="value1")
auth_callback.assert_called_once()
mock_exit_stack.close.assert_called_once()
mock_initialize.assert_called_once()
assert client._has_retried is False # Reset after completion
def test_execute_with_retry_non_auth_error(self):
"""Test execution with non-auth error (no retry)."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(side_effect=ValueError("Some other error"))
with pytest.raises(ValueError) as exc_info:
client._execute_with_retry(mock_func)
assert str(exc_info.value) == "Some other error"
mock_func.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test context manager enter with auth error and retry."""
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Mock parent class __enter__ to raise auth error first, then succeed
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
result = client.__enter__()
assert result == client
assert mock_parent_enter.call_count == 2
auth_callback.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
# Mock the parent class list_tools method
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
result = client.list_tools()
assert result == expected_tools
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test list_tools with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
# Mock parent class list_tools to raise auth error first, then succeed
with patch.object(MCPClient, "list_tools") as mock_list_tools:
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
result = client.list_tools()
assert result == expected_tools
assert mock_list_tools.call_count == 2
auth_callback.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
)
# Mock the parent class invoke_tool method
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test invoke_tool with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
# Mock parent class invoke_tool to raise auth error first, then succeed
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
assert mock_invoke_tool.call_count == 2
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
auth_callback.assert_called_once()
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Mock the parent class cleanup method
with patch.object(MCPClient, "cleanup") as mock_cleanup:
client.cleanup()
mock_cleanup.assert_called_once()
def test_cleanup_no_client(self):
"""Test cleanup when no client exists."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Should not raise
client.cleanup()
# Since MCPClientWithAuthRetry inherits from MCPClient,
# it doesn't have a _client attribute. The test should just
# verify that cleanup can be called without error.
assert not hasattr(client, "_client")