feat: add unit test

This commit is contained in:
Novice 2025-09-16 16:18:41 +08:00
parent f137af4ec5
commit e2fd3f2983
16 changed files with 2945 additions and 68 deletions

View File

@ -18,8 +18,8 @@ from controllers.console.wraps import (
setup_required,
)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPError
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
@ -974,17 +974,11 @@ class ToolMCPAuthApi(Resource):
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClientWithAuthRetry(
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity
if not provider_entity.headers
else None, # Only use auth retry if no custom headers
auth_callback=auth if not provider_entity.headers else None,
authorization_code=args.get("authorization_code"),
mcp_service=service,
):
service.update_provider_credentials(
provider=db_provider,
@ -993,7 +987,8 @@ class ToolMCPAuthApi(Resource):
)
session.commit()
return {"result": "success"}
except MCPAuthError as e:
return auth(provider_entity, service, args.get("authorization_code"))
except MCPError as e:
service.clear_provider_credentials(provider=db_provider)
session.commit()

View File

@ -1,13 +1,12 @@
"""
MCP Client with Authentication Retry Support
This module provides a wrapper around MCPClient that automatically handles
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional
from core.entities.mcp_provider import MCPProviderEntity
@ -21,12 +20,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry:
class MCPClientWithAuthRetry(MCPClient):
"""
A wrapper around MCPClient that provides automatic authentication retry.
An enhanced MCPClient that provides automatic authentication retry.
This class intercepts MCPAuthError exceptions and attempts to refresh
authentication before retrying the failed operation.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
"""
def __init__(
@ -53,27 +52,17 @@ class MCPClientWithAuthRetry:
provider_entity: Provider entity for authentication
auth_callback: Authentication callback function
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
mcp_service: MCP service instance
"""
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.auth_callback = auth_callback
self.authorization_code = authorization_code
self._has_retried = False
self._client: MCPClient | None = None
self.by_server_id = by_server_id
self.mcp_service = mcp_service
def _create_client(self) -> MCPClient:
"""Create a new MCPClient instance with current headers."""
return MCPClient(
server_url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
@ -134,38 +123,35 @@ class MCPClientWithAuthRetry:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Recreate client with new headers
if self._client:
self._client.cleanup()
self._client = self._create_client()
self._client.__enter__()
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager."""
self._client = self._create_client()
"""Enter the context manager with retry support."""
# Try to initialize with retry
def initialize():
if self._client is None:
raise ValueError("Client not created")
self._client.__enter__()
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize)
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
"""Exit the context manager."""
if self._client:
self._client.__exit__(exc_type, exc_value, traceback)
self._client = None
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server.
List available tools from the MCP server with auth retry.
Returns:
List of available tools
@ -173,13 +159,11 @@ class MCPClientWithAuthRetry:
Raises:
MCPAuthError: If authentication fails after retries
"""
if not self._client:
raise ValueError("Client not initialized. Use within a context manager.")
return self._execute_with_retry(self._client.list_tools)
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server.
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
@ -191,12 +175,4 @@ class MCPClientWithAuthRetry:
Raises:
MCPAuthError: If authentication fails after retries
"""
if not self._client:
raise ValueError("Client not initialized. Use within a context manager.")
return self._execute_with_retry(self._client.invoke_tool, tool_name, tool_args)
def cleanup(self):
"""Clean up resources."""
if self._client:
self._client.cleanup()
self._client = None
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)

View File

@ -0,0 +1 @@

View File

@ -46,7 +46,7 @@ class SSETransport:
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
@ -255,7 +255,7 @@ def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.

View File

@ -30,7 +30,7 @@ DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int, Field(strict=True)] | str
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
AnyFunction: TypeAlias = Callable[..., Any]

View File

@ -162,7 +162,6 @@ class MCPTool(Tool):
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
by_server_id=True,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)

View File

@ -0,0 +1,710 @@
"""Unit tests for MCP OAuth authentication flow."""
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_retrieve_redis_state,
auth,
check_support_resource_discovery,
discover_oauth_metadata,
exchange_authorization,
generate_pkce_challenge,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
class TestPKCEGeneration:
"""Test PKCE challenge generation."""
def test_generate_pkce_challenge(self):
"""Test PKCE challenge and verifier generation."""
code_verifier, code_challenge = generate_pkce_challenge()
# Verify format - should be URL-safe base64 without padding
assert "=" not in code_verifier
assert "+" not in code_verifier
assert "/" not in code_verifier
assert "=" not in code_challenge
assert "+" not in code_challenge
assert "/" not in code_challenge
# Verify length
assert len(code_verifier) > 40 # Should be around 54 characters
assert len(code_challenge) > 40 # Should be around 43 characters
def test_generate_pkce_challenge_uniqueness(self):
"""Test that PKCE generation produces unique values."""
results = set()
for _ in range(10):
code_verifier, code_challenge = generate_pkce_challenge()
results.add((code_verifier, code_challenge))
# All should be unique
assert len(results) == 10
class TestRedisStateManagement:
"""Test Redis state management functions."""
@patch("core.mcp.auth.auth_flow.redis_client")
def test_create_secure_redis_state(self, mock_redis):
"""Test creating secure Redis state."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
state_key = _create_secure_redis_state(state_data)
# Verify state key format
assert len(state_key) > 20 # Should be a secure random token
# Verify Redis call
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
assert state_data.model_dump_json() in call_args[0][2]
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_success(self, mock_redis):
"""Test retrieving state from Redis."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_redis.get.return_value = state_data.model_dump_json()
result = _retrieve_redis_state("test-state-key")
# Verify result
assert result.provider_id == "test-provider"
assert result.tenant_id == "test-tenant"
assert result.server_url == "https://example.com"
# Verify Redis calls
mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_not_found(self, mock_redis):
"""Test retrieving non-existent state from Redis."""
mock_redis.get.return_value = None
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("nonexistent-key")
assert "State parameter has expired or does not exist" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_invalid_json(self, mock_redis):
"""Test retrieving invalid JSON state from Redis."""
mock_redis.get.return_value = '{"invalid": json}'
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("test-key")
assert "Invalid state parameter" in str(exc_info.value)
# State should still be deleted
mock_redis.delete.assert_called_once()
class TestOAuthDiscovery:
"""Test OAuth discovery functions."""
@patch("httpx.get")
def test_check_support_resource_discovery_success(self, mock_get):
"""Test successful resource discovery check."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource/endpoint",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("httpx.get")
def test_check_support_resource_discovery_not_supported(self, mock_get):
"""Test resource discovery not supported."""
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com")
assert supported is False
assert auth_url == ""
@patch("httpx.get")
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
"""Test resource discovery with query and fragment."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource/path?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("httpx.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (True, "https://auth.example.com/.well-known/oauth-authorization-server")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("httpx.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("httpx.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
"""Test OAuth metadata discovery when not found."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is None
class TestAuthorizationFlow:
"""Test authorization flow functions."""
@patch("core.mcp.auth.auth_flow._create_secure_redis_state")
def test_start_authorization_with_metadata(self, mock_create_state):
"""Test starting authorization with metadata."""
mock_create_state.return_value = "secure-state-key"
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
code_challenge_methods_supported=["S256"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Verify URL format
assert auth_url.startswith("https://auth.example.com/authorize?")
assert "response_type=code" in auth_url
assert "client_id=test-client-id" in auth_url
assert "code_challenge=" in auth_url
assert "code_challenge_method=S256" in auth_url
assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
assert "state=secure-state-key" in auth_url
# Verify code verifier
assert len(code_verifier) > 40
# Verify state was stored
mock_create_state.assert_called_once()
state_data = mock_create_state.call_args[0][0]
assert state_data.provider_id == "provider-id"
assert state_data.tenant_id == "tenant-id"
assert state_data.code_verifier == code_verifier
def test_start_authorization_without_metadata(self):
"""Test starting authorization without metadata."""
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
mock_create_state.return_value = "secure-state-key"
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
None,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Should use default authorization endpoint
assert auth_url.startswith("https://api.example.com/authorize?")
def test_start_authorization_invalid_metadata(self):
"""Test starting authorization with invalid metadata."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["token"], # No "code" support
code_challenge_methods_supported=["plain"], # No "S256" support
)
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
assert "does not support response type code" in str(exc_info.value)
@patch("httpx.post")
def test_exchange_authorization_success(self, mock_post):
"""Test successful authorization code exchange."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
tokens = exchange_authorization(
"https://api.example.com",
metadata,
client_info,
"auth-code-123",
"code-verifier-xyz",
"https://redirect.example.com",
)
assert tokens.access_token == "new-access-token"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "authorization_code",
"client_id": "test-client-id",
"client_secret": "test-secret",
"code": "auth-code-123",
"code_verifier": "code-verifier-xyz",
"redirect_uri": "https://redirect.example.com",
},
)
@patch("httpx.post")
def test_exchange_authorization_failure(self, mock_post):
"""Test failed authorization code exchange."""
mock_response = Mock()
mock_response.is_success = False
mock_response.status_code = 400
mock_post.return_value = mock_response
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
exchange_authorization(
"https://api.example.com",
None,
client_info,
"invalid-code",
"code-verifier",
"https://redirect.example.com",
)
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
@patch("httpx.post")
def test_refresh_authorization_success(self, mock_post):
"""Test successful token refresh."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "refreshed-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["refresh_token"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")
assert tokens.access_token == "refreshed-access-token"
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "refresh_token",
"client_id": "test-client-id",
"refresh_token": "old-refresh-token",
},
)
@patch("httpx.post")
def test_register_client_success(self, mock_post):
"""Test successful client registration."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"client_name": "Dify",
"redirect_uris": ["https://redirect.example.com"],
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
grant_types=["authorization_code"],
response_types=["code"],
)
client_info = register_client("https://api.example.com", metadata, client_metadata)
assert isinstance(client_info, OAuthClientInformationFull)
assert client_info.client_id == "new-client-id"
assert client_info.client_secret == "new-client-secret"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/register",
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
def test_register_client_no_endpoint(self):
"""Test client registration when no endpoint available."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint=None,
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])
with pytest.raises(ValueError) as exc_info:
register_client("https://api.example.com", metadata, client_metadata)
assert "does not support dynamic client registration" in str(exc_info.value)
class TestCallbackHandling:
"""Test OAuth callback handling."""
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
"""Test successful callback handling."""
# Setup state
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(
access_token="new-token",
token_type="Bearer",
expires_in=3600,
)
mock_exchange.return_value = tokens
# Setup service
mock_service = Mock()
result = handle_callback("state-key", "auth-code", mock_service)
assert result == state_data
# Verify calls
mock_retrieve_state.assert_called_once_with("state-key")
mock_exchange.assert_called_once_with(
"https://api.example.com",
None,
state_data.client_information,
"auth-code",
"test-verifier",
"https://redirect.example.com",
)
mock_service.save_oauth_data.assert_called_once_with(
"test-provider", "test-tenant", tokens.model_dump(), "tokens"
)
class TestAuthOrchestration:
"""Test the main auth orchestration function."""
@pytest.fixture
def mock_provider(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "provider-id"
provider.tenant_id = "tenant-id"
provider.decrypt_server_url.return_value = "https://api.example.com"
provider.client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
provider.redirect_url = "https://redirect.example.com"
provider.retrieve_client_information.return_value = None
provider.retrieve_tokens.return_value = None
return provider
@pytest.fixture
def mock_service(self):
"""Create a mock MCP service."""
return Mock()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow.register_client")
@patch("core.mcp.auth.auth_flow.start_authorization")
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = None
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
result = auth(mock_provider, mock_service)
assert result == {"authorization_url": "https://auth.example.com/authorize?..."}
# Verify calls
mock_register.assert_called_once()
mock_service.save_oauth_data.assert_any_call(
"provider-id",
"tenant-id",
{"client_information": mock_register.return_value.model_dump()},
"client_info",
)
mock_service.save_oauth_data.assert_any_call(
"provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
"""Test auth flow for exchanging authorization code."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
# Setup existing client
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
# Setup state retrieval
state_data = OAuthCallbackState(
provider_id="provider-id",
tenant_id="tenant-id",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="existing-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
mock_exchange.return_value = tokens
result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key")
assert result == {"result": "success"}
# Verify token save
mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens")
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
"""Test auth flow fails when exchanging code without state."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
assert "State parameter is required" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.refresh_authorization")
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
"""Test auth flow for refreshing tokens."""
# Setup existing client and tokens
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
mock_provider.retrieve_tokens.return_value = OAuthTokens(
access_token="old-token",
token_type="Bearer",
expires_in=0,
refresh_token="refresh-token",
)
# Setup refresh
new_tokens = OAuthTokens(
access_token="refreshed-token",
token_type="Bearer",
expires_in=3600,
refresh_token="new-refresh-token",
)
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = None
result = auth(mock_provider, mock_service)
assert result == {"result": "success"}
# Verify refresh was called
mock_refresh.assert_called_once()
mock_service.save_oauth_data.assert_called_with(
"provider-id", "tenant-id", new_tokens.model_dump(), "tokens"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
"""Test auth fails when no client info exists but code is provided."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)

View File

@ -0,0 +1,523 @@
"""Unit tests for MCP auth client with retry logic."""
from types import TracebackType
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
class TestMCPClientWithAuthRetry:
"""Test suite for MCPClientWithAuthRetry."""
@pytest.fixture
def mock_provider_entity(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "test-provider-id"
provider.tenant_id = "test-tenant-id"
provider.retrieve_tokens.return_value = Mock(
access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
return provider
@pytest.fixture
def mock_mcp_service(self):
"""Create a mock MCP service."""
service = Mock()
service.get_provider_entity.return_value = Mock(
retrieve_tokens=lambda: Mock(
access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
)
return service
@pytest.fixture
def auth_callback(self):
"""Create a mock auth callback."""
return Mock()
def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test client initialization."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-auth-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
assert client.server_url == "http://test.example.com"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client.provider_entity == mock_provider_entity
assert client.auth_callback == auth_callback
assert client.authorization_code == "test-auth-code"
assert client.by_server_id is True
assert client.mcp_service == mock_mcp_service
assert client._has_retried is False
# In inheritance design, we don't have _client attribute
assert hasattr(client, "_session") # Inherited from MCPClient
def test_inheritance_structure(self):
"""Test that MCPClientWithAuthRetry properly inherits from MCPClient."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
)
# Verify inheritance
assert isinstance(client, MCPClient)
# Verify inherited attributes are accessible
assert hasattr(client, "server_url")
assert hasattr(client, "headers")
assert hasattr(client, "_session")
assert hasattr(client, "_exit_stack")
assert hasattr(client, "_initialized")
def test_handle_auth_error_no_retry_components(self):
"""Test auth error handling when retry components are missing."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth error handling when already retried."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
client._has_retried = True
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
auth_callback.assert_not_called()
def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test successful auth refresh on error."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
# Configure mocks
new_provider = Mock(spec=MCPProviderEntity)
new_provider.id = "test-provider-id"
new_provider.tenant_id = "test-tenant-id"
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
client._handle_auth_error(error)
# Verify auth flow
auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code")
mock_mcp_service.get_provider_entity.assert_called_once_with(
"test-provider-id", "test-tenant-id", by_server_id=True
)
assert client.headers["Authorization"] == "Bearer new-token"
assert client.authorization_code is None # Should be cleared after use
assert client._has_retried is True
def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh failure."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
auth_callback.side_effect = Exception("Auth callback failed")
error = MCPAuthError("Original auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "Authentication retry failed" in str(exc_info.value)
def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh when no token is received."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock to return no token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = None
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "no token received" in str(exc_info.value)
def test_execute_with_retry_success(self):
"""Test successful execution without retry."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(return_value="success")
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
mock_func.assert_called_once_with("arg1", kwarg1="value1")
assert client._has_retried is False
@patch("core.mcp.auth_client.MCPClient")
def test_execute_with_retry_auth_error_then_success(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
"""Test execution with auth error followed by successful retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock clients (old and new)
mock_client_old = MagicMock()
mock_client_new = MagicMock()
client._client = mock_client_old
# Make _create_client return the new client on retry
mock_mcp_client_class.return_value = mock_client_new
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
# Mock function that fails first, then succeeds
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
assert mock_func.call_count == 2
mock_func.assert_called_with("arg1", kwarg1="value1")
auth_callback.assert_called_once()
mock_client_old.cleanup.assert_called_once()
mock_client_new.__enter__.assert_called_once()
assert client._has_retried is False # Reset after completion
def test_execute_with_retry_non_auth_error(self):
"""Test execution with non-auth error (no retry)."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(side_effect=ValueError("Some other error"))
with pytest.raises(ValueError) as exc_info:
client._execute_with_retry(mock_func)
assert str(exc_info.value) == "Some other error"
mock_func.assert_called_once()
@patch("core.mcp.auth_client.MCPClient")
def test_context_manager_enter(self, mock_mcp_client_class):
"""Test context manager enter."""
mock_client_instance = MagicMock()
mock_client_instance.__enter__.return_value = mock_client_instance
mock_mcp_client_class.return_value = mock_client_instance
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
result = client.__enter__()
assert result == client
assert client._client == mock_client_instance
mock_client_instance.__enter__.assert_called_once()
@patch("core.mcp.auth_client.MCPClient")
def test_context_manager_enter_with_auth_error(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
"""Test context manager enter with auth error and retry."""
mock_client_instance = MagicMock()
mock_mcp_client_class.return_value = mock_client_instance
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# First call to client.__enter__ raises auth error, second succeeds
call_count = 0
def enter_side_effect():
nonlocal call_count
call_count += 1
if call_count == 1:
raise MCPAuthError("Auth failed")
return mock_client_instance
mock_client_instance.__enter__.side_effect = enter_side_effect
result = client.__enter__()
assert result == client
assert mock_client_instance.__enter__.call_count == 3
auth_callback.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = MagicMock()
client._client = mock_client
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_client.__exit__.assert_called_once_with(None, None, None)
assert client._client is None
def test_list_tools_not_initialized(self):
"""Test list_tools when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Client not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = Mock()
client._client = mock_client
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
mock_client.list_tools.return_value = expected_tools
result = client.list_tools()
assert result == expected_tools
mock_client.list_tools.assert_called_once()
@patch("core.mcp.auth_client.MCPClient")
def test_list_tools_with_auth_retry(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
"""Test list_tools with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock clients (old and new)
mock_client_old = MagicMock()
mock_client_new = MagicMock()
client._client = mock_client_old
# Make _create_client return the new client on retry
mock_mcp_client_class.return_value = mock_client_new
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
# First call raises auth error
mock_client_old.list_tools.side_effect = MCPAuthError("Auth failed")
mock_client_new.list_tools.return_value = expected_tools
# We need to mock the behavior where after client is replaced,
# the new method should be called. But since the method reference
# is already bound to the old client, we need to work around this.
# Let's patch the _execute_with_retry to handle this properly.
with patch.object(client, "_execute_with_retry") as mock_execute:
# Simulate the retry behavior
call_count = [0]
def execute_side_effect(func, *args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
# First call - simulate auth error and retry
try:
func(*args, **kwargs)
except MCPAuthError:
# Simulate the retry logic
client._handle_auth_error(MCPAuthError("Auth failed"))
if client._client:
client._client.cleanup()
client._client = mock_client_new
client._client.__enter__()
# Now return the result from the new client
return mock_client_new.list_tools(*args, **kwargs)
return func(*args, **kwargs)
mock_execute.side_effect = execute_side_effect
result = client.list_tools()
assert result == expected_tools
mock_client_old.list_tools.assert_called_once()
mock_client_new.list_tools.assert_called_once()
auth_callback.assert_called_once()
mock_client_old.cleanup.assert_called_once()
mock_client_new.__enter__.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Client not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = Mock()
client._client = mock_client
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
)
mock_client.invoke_tool.return_value = expected_result
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_client.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
@patch("core.mcp.auth_client.MCPClient")
def test_invoke_tool_with_auth_retry(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
"""Test invoke_tool with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock clients (old and new)
mock_client_old = MagicMock()
mock_client_new = MagicMock()
client._client = mock_client_old
# Make _create_client return the new client on retry
mock_mcp_client_class.return_value = mock_client_new
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
# First call raises auth error
mock_client_old.invoke_tool.side_effect = MCPAuthError("Auth failed")
mock_client_new.invoke_tool.return_value = expected_result
# We need to mock the behavior where after client is replaced,
# the new method should be called. Similar to list_tools test.
with patch.object(client, "_execute_with_retry") as mock_execute:
# Simulate the retry behavior
call_count = [0]
def execute_side_effect(func, *args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
# First call - simulate auth error and retry
try:
func(*args, **kwargs)
except MCPAuthError:
# Simulate the retry logic
client._handle_auth_error(MCPAuthError("Auth failed"))
if client._client:
client._client.cleanup()
client._client = mock_client_new
client._client.__enter__()
# Now return the result from the new client
return mock_client_new.invoke_tool(*args, **kwargs)
return func(*args, **kwargs)
mock_execute.side_effect = execute_side_effect
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_client_old.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
mock_client_new.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
auth_callback.assert_called_once()
mock_client_old.cleanup.assert_called_once()
mock_client_new.__enter__.assert_called_once()
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = Mock()
client._client = mock_client
client.cleanup()
mock_client.cleanup.assert_called_once()
assert client._client is None
def test_cleanup_no_client(self):
"""Test cleanup when no client exists."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Should not raise
client.cleanup()
assert client._client is None

View File

@ -0,0 +1,239 @@
"""Unit tests for MCP entities module."""
from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
class TestProtocolVersions:
"""Test protocol version constants."""
def test_supported_protocol_versions(self):
"""Test supported protocol versions list."""
assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list)
assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3
assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS
assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
def test_latest_protocol_version_is_supported(self):
"""Test that latest protocol version is in supported versions."""
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
class TestRequestContext:
"""Test RequestContext dataclass."""
def test_request_context_creation(self):
"""Test creating a RequestContext instance."""
mock_session = Mock(spec=BaseSession)
mock_lifespan = {"key": "value"}
mock_meta = RequestParams.Meta(progressToken="test-token")
context = RequestContext(
request_id="test-request-123",
meta=mock_meta,
session=mock_session,
lifespan_context=mock_lifespan,
)
assert context.request_id == "test-request-123"
assert context.meta == mock_meta
assert context.session == mock_session
assert context.lifespan_context == mock_lifespan
def test_request_context_with_none_meta(self):
"""Test creating RequestContext with None meta."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id=42, # Can be int or string
meta=None,
session=mock_session,
lifespan_context=None,
)
assert context.request_id == 42
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_attributes(self):
"""Test RequestContext attributes are accessible."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context=None,
)
# Verify attributes are accessible
assert hasattr(context, "request_id")
assert hasattr(context, "meta")
assert hasattr(context, "session")
assert hasattr(context, "lifespan_context")
# Verify values
assert context.request_id == "test-123"
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_generic_typing(self):
"""Test RequestContext with different generic types."""
# Create a mock session with specific type
mock_session = Mock(spec=BaseSession)
# Create context with string lifespan context
context_str = RequestContext[BaseSession, str](
request_id="test-1",
meta=None,
session=mock_session,
lifespan_context="string-context",
)
assert isinstance(context_str.lifespan_context, str)
# Create context with dict lifespan context
context_dict = RequestContext[BaseSession, dict](
request_id="test-2",
meta=None,
session=mock_session,
lifespan_context={"key": "value"},
)
assert isinstance(context_dict.lifespan_context, dict)
# Create context with custom object lifespan context
class CustomLifespan:
def __init__(self, data):
self.data = data
custom_lifespan = CustomLifespan("test-data")
context_custom = RequestContext[BaseSession, CustomLifespan](
request_id="test-3",
meta=None,
session=mock_session,
lifespan_context=custom_lifespan,
)
assert isinstance(context_custom.lifespan_context, CustomLifespan)
assert context_custom.lifespan_context.data == "test-data"
def test_request_context_with_progress_meta(self):
"""Test RequestContext with progress metadata."""
mock_session = Mock(spec=BaseSession)
progress_meta = RequestParams.Meta(progressToken="progress-123")
context = RequestContext(
request_id="req-456",
meta=progress_meta,
session=mock_session,
lifespan_context=None,
)
assert context.meta is not None
assert context.meta.progressToken == "progress-123"
def test_request_context_equality(self):
"""Test RequestContext equality comparison."""
mock_session1 = Mock(spec=BaseSession)
mock_session2 = Mock(spec=BaseSession)
context1 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context2 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context3 = RequestContext(
request_id="test-456",
meta=None,
session=mock_session1,
lifespan_context="context",
)
# Same values should be equal
assert context1 == context2
# Different request_id should not be equal
assert context1 != context3
# Different session should not be equal
context4 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session2,
lifespan_context="context",
)
assert context1 != context4
def test_request_context_repr(self):
"""Test RequestContext string representation."""
mock_session = Mock(spec=BaseSession)
mock_session.__repr__ = Mock(return_value="<MockSession>")
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context={"data": "test"},
)
repr_str = repr(context)
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@ -0,0 +1,205 @@
"""Unit tests for MCP error classes."""
import pytest
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
class TestMCPError:
"""Test MCPError base exception class."""
def test_mcp_error_creation(self):
"""Test creating MCPError instance."""
error = MCPError("Test error message")
assert str(error) == "Test error message"
assert isinstance(error, Exception)
def test_mcp_error_inheritance(self):
"""Test MCPError inherits from Exception."""
error = MCPError()
assert isinstance(error, Exception)
assert type(error).__name__ == "MCPError"
def test_mcp_error_with_empty_message(self):
"""Test MCPError with empty message."""
error = MCPError()
assert str(error) == ""
def test_mcp_error_raise(self):
"""Test raising MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPError("Something went wrong")
assert str(exc_info.value) == "Something went wrong"
class TestMCPConnectionError:
"""Test MCPConnectionError exception class."""
def test_mcp_connection_error_creation(self):
"""Test creating MCPConnectionError instance."""
error = MCPConnectionError("Connection failed")
assert str(error) == "Connection failed"
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_inheritance(self):
"""Test MCPConnectionError inheritance chain."""
error = MCPConnectionError()
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_raise(self):
"""Test raising MCPConnectionError."""
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPConnectionError("Unable to connect to server")
assert str(exc_info.value) == "Unable to connect to server"
def test_mcp_connection_error_catch_as_mcp_error(self):
"""Test catching MCPConnectionError as MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPConnectionError("Connection issue")
assert isinstance(exc_info.value, MCPConnectionError)
assert str(exc_info.value) == "Connection issue"
class TestMCPAuthError:
"""Test MCPAuthError exception class."""
def test_mcp_auth_error_creation(self):
"""Test creating MCPAuthError instance."""
error = MCPAuthError("Authentication failed")
assert str(error) == "Authentication failed"
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_inheritance(self):
"""Test MCPAuthError inheritance chain."""
error = MCPAuthError()
assert isinstance(error, MCPAuthError)
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_raise(self):
"""Test raising MCPAuthError."""
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Invalid credentials")
assert str(exc_info.value) == "Invalid credentials"
def test_mcp_auth_error_catch_hierarchy(self):
"""Test catching MCPAuthError at different levels."""
# Catch as MCPAuthError
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Auth specific error")
assert str(exc_info.value) == "Auth specific error"
# Catch as MCPConnectionError
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPAuthError("Auth connection error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth connection error"
# Catch as MCPError
with pytest.raises(MCPError) as exc_info:
raise MCPAuthError("Auth base error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth base error"
class TestErrorHierarchy:
"""Test the complete error hierarchy."""
def test_exception_hierarchy(self):
"""Test the complete exception hierarchy."""
# Create instances
base_error = MCPError("base")
connection_error = MCPConnectionError("connection")
auth_error = MCPAuthError("auth")
# Test type relationships
assert not isinstance(base_error, MCPConnectionError)
assert not isinstance(base_error, MCPAuthError)
assert isinstance(connection_error, MCPError)
assert not isinstance(connection_error, MCPAuthError)
assert isinstance(auth_error, MCPError)
assert isinstance(auth_error, MCPConnectionError)
def test_error_handling_patterns(self):
"""Test common error handling patterns."""
def raise_auth_error():
raise MCPAuthError("401 Unauthorized")
def raise_connection_error():
raise MCPConnectionError("Connection timeout")
def raise_base_error():
raise MCPError("Generic error")
# Pattern 1: Catch specific errors first
errors_caught = []
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
try:
error_func()
except MCPAuthError:
errors_caught.append("auth")
except MCPConnectionError:
errors_caught.append("connection")
except MCPError:
errors_caught.append("base")
assert errors_caught == ["auth", "connection", "base"]
# Pattern 2: Catch all as base error
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
with pytest.raises(MCPError) as exc_info:
error_func()
assert isinstance(exc_info.value, MCPError)
def test_error_with_cause(self):
"""Test errors with cause (chained exceptions)."""
original_error = ValueError("Original error")
def raise_chained_error():
try:
raise original_error
except ValueError as e:
raise MCPConnectionError("Connection failed") from e
with pytest.raises(MCPConnectionError) as exc_info:
raise_chained_error()
assert str(exc_info.value) == "Connection failed"
assert exc_info.value.__cause__ == original_error
def test_error_comparison(self):
"""Test error instance comparison."""
error1 = MCPError("Test message")
error2 = MCPError("Test message")
error3 = MCPError("Different message")
# Errors are not equal even with same message (different instances)
assert error1 != error2
assert error1 != error3
# But they have the same type
assert type(error1) == type(error2) == type(error3)
def test_error_representation(self):
"""Test error string representation."""
base_error = MCPError("Base error message")
connection_error = MCPConnectionError("Connection error message")
auth_error = MCPAuthError("Auth error message")
assert repr(base_error) == "MCPError('Base error message')"
assert repr(connection_error) == "MCPConnectionError('Connection error message')"
assert repr(auth_error) == "MCPAuthError('Auth error message')"

View File

@ -0,0 +1,382 @@
"""Unit tests for MCP client."""
from contextlib import ExitStack
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.mcp.error import MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
class TestMCPClient:
"""Test suite for MCPClient."""
def test_init(self):
"""Test client initialization."""
client = MCPClient(
server_url="http://test.example.com/mcp",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
)
assert client.server_url == "http://test.example.com/mcp"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client._session is None
assert isinstance(client._exit_stack, ExitStack)
assert client._initialized is False
def test_init_defaults(self):
"""Test client initialization with defaults."""
client = MCPClient(server_url="http://test.example.com")
assert client.server_url == "http://test.example.com"
assert client.headers == {}
assert client.timeout is None
assert client.sse_read_timeout is None
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
"""Test initialization with MCP URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/mcp")
client._initialize()
# Verify streamable client was called
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
"""Test initialization with SSE URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/sse")
client._initialize()
# Verify SSE client was called
mock_sse_client.assert_called_once_with(
url="http://test.example.com/sse",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_unknown_method_fallback_to_sse(
self, mock_client_session, mock_streamable_client, mock_sse_client
):
"""Test initialization with unknown method falls back to SSE."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify SSE client was tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_not_called()
# Verify session was created
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
"""Test initialization falls back from SSE to MCP on connection error."""
# Setup SSE to fail
mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")
# Setup MCP to succeed
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify both were tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_called_once()
# Verify session was created with MCP
assert client._session == mock_session
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
"""Test connect_server with MCP method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_streamable_client, "mcp")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_sse(self, mock_client_session, mock_sse_client):
"""Test connect_server with SSE method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_sse_client, "sse")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
client._session = mock_session
result = client.list_tools()
assert result == expected_tools
mock_session.list_tools.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
isError=False,
)
mock_session.call_tool.return_value = expected_result
client._session = mock_session
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
client.cleanup()
mock_exit_stack.close.assert_called_once()
assert client._session is None
assert client._initialized is False
def test_cleanup_with_error(self):
"""Test cleanup method with error."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
mock_exit_stack.close.side_effect = Exception("Cleanup error")
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
with pytest.raises(ValueError) as exc_info:
client.cleanup()
assert "Error during cleanup: Cleanup error" in str(exc_info.value)
assert client._session is None
assert client._initialized is False
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
"""Test full context manager flow."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
with MCPClient(server_url="http://test.example.com/mcp") as client:
assert client._initialized is True
assert client._session == mock_session
# Test tool operations
tools = client.list_tools()
assert tools == expected_tools
# After exit, should be cleaned up
assert client._initialized is False
assert client._session is None
def test_headers_passed_to_clients(self):
"""Test that headers are properly passed to underlying clients."""
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
}
with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client:
with patch("core.mcp.mcp_client.ClientSession") as mock_client_session:
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(
server_url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)
client._initialize()
# Verify headers were passed
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)

View File

@ -0,0 +1,492 @@
"""Unit tests for MCP types module."""
import pytest
from pydantic import ValidationError
from core.mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
METHOD_NOT_FOUND,
PARSE_ERROR,
SERVER_LATEST_PROTOCOL_VERSION,
Annotations,
CallToolRequest,
CallToolRequestParams,
CallToolResult,
ClientCapabilities,
CompleteRequest,
CompleteRequestParams,
CompleteResult,
Completion,
CompletionArgument,
CompletionContext,
ErrorData,
ImageContent,
Implementation,
InitializeRequest,
InitializeRequestParams,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsRequest,
ListToolsResult,
OAuthClientInformation,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
PingRequest,
ProgressNotification,
ProgressNotificationParams,
PromptReference,
RequestParams,
ResourceTemplateReference,
Result,
ServerCapabilities,
TextContent,
Tool,
ToolAnnotations,
)
class TestConstants:
"""Test module constants."""
def test_protocol_versions(self):
"""Test protocol version constants."""
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
def test_error_codes(self):
"""Test JSON-RPC error code constants."""
assert PARSE_ERROR == -32700
assert INVALID_REQUEST == -32600
assert METHOD_NOT_FOUND == -32601
assert INVALID_PARAMS == -32602
assert INTERNAL_ERROR == -32603
class TestRequestParams:
"""Test RequestParams and related classes."""
def test_request_params_basic(self):
"""Test basic RequestParams creation."""
params = RequestParams()
assert params.meta is None
def test_request_params_with_meta(self):
"""Test RequestParams with meta."""
meta = RequestParams.Meta(progressToken="test-token")
params = RequestParams(_meta=meta)
assert params.meta is not None
assert params.meta.progressToken == "test-token"
def test_request_params_meta_extra_fields(self):
"""Test RequestParams.Meta allows extra fields."""
meta = RequestParams.Meta(progressToken="token", customField="value")
assert meta.progressToken == "token"
assert meta.customField == "value" # type: ignore
def test_request_params_serialization(self):
"""Test RequestParams serialization with _meta alias."""
meta = RequestParams.Meta(progressToken="test")
params = RequestParams(_meta=meta)
# Model dump should use the alias
dumped = params.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] is not None
assert dumped["_meta"]["progressToken"] == "test"
class TestJSONRPCMessages:
"""Test JSON-RPC message types."""
def test_jsonrpc_request(self):
"""Test JSONRPCRequest creation and validation."""
request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})
assert request.jsonrpc == "2.0"
assert request.id == "test-123"
assert request.method == "test_method"
assert request.params == {"key": "value"}
def test_jsonrpc_request_numeric_id(self):
"""Test JSONRPCRequest with numeric ID."""
request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
assert request.id == 123
def test_jsonrpc_notification(self):
"""Test JSONRPCNotification creation."""
notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})
assert notification.jsonrpc == "2.0"
assert notification.method == "notification_method"
assert not hasattr(notification, "id") # Notifications don't have ID
def test_jsonrpc_response(self):
"""Test JSONRPCResponse creation."""
response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})
assert response.jsonrpc == "2.0"
assert response.id == "req-123"
assert response.result == {"success": True}
def test_jsonrpc_error(self):
"""Test JSONRPCError creation."""
error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data)
assert error.jsonrpc == "2.0"
assert error.id == "req-123"
assert error.error.code == INVALID_PARAMS
assert error.error.message == "Invalid parameters"
assert error.error.data == {"field": "missing"}
def test_jsonrpc_message_parsing(self):
"""Test JSONRPCMessage parsing different message types."""
# Parse request
request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}'
msg = JSONRPCMessage.model_validate_json(request_json)
assert isinstance(msg.root, JSONRPCRequest)
# Parse response
response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}'
msg = JSONRPCMessage.model_validate_json(response_json)
assert isinstance(msg.root, JSONRPCResponse)
# Parse error
error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}'
msg = JSONRPCMessage.model_validate_json(error_json)
assert isinstance(msg.root, JSONRPCError)
class TestCapabilities:
"""Test capability classes."""
def test_client_capabilities(self):
"""Test ClientCapabilities creation."""
caps = ClientCapabilities(
experimental={"feature": {"enabled": True}},
sampling={"model_config": {"extra": "allow"}},
roots={"listChanged": True},
)
assert caps.experimental == {"feature": {"enabled": True}}
assert caps.sampling is not None
assert caps.roots.listChanged is True # type: ignore
def test_server_capabilities(self):
"""Test ServerCapabilities creation."""
caps = ServerCapabilities(
tools={"listChanged": True},
resources={"subscribe": True, "listChanged": False},
prompts={"listChanged": True},
logging={},
completions={},
)
assert caps.tools.listChanged is True # type: ignore
assert caps.resources.subscribe is True # type: ignore
assert caps.resources.listChanged is False # type: ignore
class TestInitialization:
"""Test initialization request/response types."""
def test_initialize_request(self):
"""Test InitializeRequest creation."""
client_info = Implementation(name="test-client", version="1.0.0")
capabilities = ClientCapabilities()
params = InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info
)
request = InitializeRequest(params=params)
assert request.method == "initialize"
assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
assert request.params.clientInfo.name == "test-client"
def test_initialize_result(self):
"""Test InitializeResult creation."""
server_info = Implementation(name="test-server", version="1.0.0")
capabilities = ServerCapabilities()
result = InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=server_info,
instructions="Welcome to test server",
)
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert result.serverInfo.name == "test-server"
assert result.instructions == "Welcome to test server"
class TestTools:
"""Test tool-related types."""
def test_tool_creation(self):
"""Test Tool creation with all fields."""
tool = Tool(
name="test_tool",
title="Test Tool",
description="A tool for testing",
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
annotations=ToolAnnotations(
title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
),
)
assert tool.name == "test_tool"
assert tool.title == "Test Tool"
assert tool.description == "A tool for testing"
assert tool.inputSchema["properties"]["input"]["type"] == "string"
assert tool.annotations.idempotentHint is True
def test_call_tool_request(self):
"""Test CallToolRequest creation."""
params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
request = CallToolRequest(params=params)
assert request.method == "tools/call"
assert request.params.name == "test_tool"
assert request.params.arguments == {"input": "test value"}
def test_call_tool_result(self):
"""Test CallToolResult creation."""
result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
structuredContent={"status": "success", "data": "test"},
isError=False,
)
assert len(result.content) == 1
assert result.content[0].text == "Tool executed successfully" # type: ignore
assert result.structuredContent == {"status": "success", "data": "test"}
assert result.isError is False
def test_list_tools_request(self):
"""Test ListToolsRequest creation."""
request = ListToolsRequest()
assert request.method == "tools/list"
def test_list_tools_result(self):
"""Test ListToolsResult creation."""
tool1 = Tool(name="tool1", inputSchema={})
tool2 = Tool(name="tool2", inputSchema={})
result = ListToolsResult(tools=[tool1, tool2])
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool2"
class TestContent:
"""Test content types."""
def test_text_content(self):
"""Test TextContent creation."""
annotations = Annotations(audience=["user"], priority=0.8)
content = TextContent(type="text", text="Hello, world!", annotations=annotations)
assert content.type == "text"
assert content.text == "Hello, world!"
assert content.annotations is not None
assert content.annotations.priority == 0.8
def test_image_content(self):
"""Test ImageContent creation."""
content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")
assert content.type == "image"
assert content.data == "base64encodeddata"
assert content.mimeType == "image/png"
class TestOAuth:
"""Test OAuth-related types."""
def test_oauth_client_metadata(self):
"""Test OAuthClientMetadata creation."""
metadata = OAuthClientMetadata(
client_name="Test Client",
redirect_uris=["https://example.com/callback"],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
client_uri="https://example.com",
scope="read write",
)
assert metadata.client_name == "Test Client"
assert len(metadata.redirect_uris) == 1
assert "authorization_code" in metadata.grant_types
def test_oauth_client_information(self):
"""Test OAuthClientInformation creation."""
info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
assert info.client_id == "test-client-id"
assert info.client_secret == "test-secret"
def test_oauth_client_information_without_secret(self):
"""Test OAuthClientInformation without secret."""
info = OAuthClientInformation(client_id="public-client")
assert info.client_id == "public-client"
assert info.client_secret is None
def test_oauth_tokens(self):
"""Test OAuthTokens creation."""
tokens = OAuthTokens(
access_token="access-token-123",
token_type="Bearer",
expires_in=3600,
refresh_token="refresh-token-456",
scope="read write",
)
assert tokens.access_token == "access-token-123"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "refresh-token-456"
assert tokens.scope == "read write"
def test_oauth_metadata(self):
"""Test OAuthMetadata creation."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code", "token"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["plain", "S256"],
)
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert "code" in metadata.response_types_supported
assert "S256" in metadata.code_challenge_methods_supported
class TestNotifications:
"""Test notification types."""
def test_progress_notification(self):
"""Test ProgressNotification creation."""
params = ProgressNotificationParams(
progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
)
notification = ProgressNotification(params=params)
assert notification.method == "notifications/progress"
assert notification.params.progressToken == "progress-123"
assert notification.params.progress == 50.0
assert notification.params.total == 100.0
assert notification.params.message == "Processing... 50%"
def test_ping_request(self):
"""Test PingRequest creation."""
request = PingRequest()
assert request.method == "ping"
assert request.params is None
class TestCompletion:
"""Test completion-related types."""
def test_completion_context(self):
"""Test CompletionContext creation."""
context = CompletionContext(arguments={"template_var": "value"})
assert context.arguments == {"template_var": "value"}
def test_resource_template_reference(self):
"""Test ResourceTemplateReference creation."""
ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}")
assert ref.type == "ref/resource"
assert ref.uri == "file:///path/to/{filename}"
def test_prompt_reference(self):
"""Test PromptReference creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
assert ref.type == "ref/prompt"
assert ref.name == "test_prompt"
def test_complete_request(self):
"""Test CompleteRequest creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
arg = CompletionArgument(name="arg1", value="val")
params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"}))
request = CompleteRequest(params=params)
assert request.method == "completion/complete"
assert request.params.ref.name == "test_prompt" # type: ignore
assert request.params.argument.name == "arg1"
def test_complete_result(self):
"""Test CompleteResult creation."""
completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
result = CompleteResult(completion=completion)
assert len(result.completion.values) == 3
assert result.completion.total == 10
assert result.completion.hasMore is True
class TestValidation:
"""Test validation of various types."""
def test_invalid_jsonrpc_version(self):
"""Test invalid JSON-RPC version validation."""
with pytest.raises(ValidationError):
JSONRPCRequest(
jsonrpc="1.0", # Invalid version
id=1,
method="test",
)
def test_tool_annotations_validation(self):
"""Test ToolAnnotations with invalid values."""
# Valid annotations
annotations = ToolAnnotations(
title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
)
assert annotations.title == "Test"
def test_extra_fields_allowed(self):
"""Test that extra fields are allowed in models."""
# Most models should allow extra fields
tool = Tool(
name="test",
inputSchema={},
customField="allowed", # type: ignore
)
assert tool.customField == "allowed" # type: ignore
def test_result_meta_alias(self):
"""Test Result model with _meta alias."""
# Create with the field name (not alias)
result = Result(_meta={"key": "value"})
# Verify the field is set correctly
assert result.meta == {"key": "value"}
# Dump with alias
dumped = result.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] == {"key": "value"}

View File

@ -0,0 +1,355 @@
"""Unit tests for MCP utils module."""
import json
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import httpx
import httpx_sse
import pytest
from core.mcp.utils import (
STATUS_FORCELIST,
create_mcp_error_response,
create_ssrf_proxy_mcp_http_client,
ssrf_proxy_sse_connect,
)
class TestConstants:
"""Test module constants."""
def test_status_forcelist(self):
"""Test STATUS_FORCELIST contains expected HTTP status codes."""
assert STATUS_FORCELIST == [429, 500, 502, 503, 504]
assert 429 in STATUS_FORCELIST # Too Many Requests
assert 500 in STATUS_FORCELIST # Internal Server Error
assert 502 in STATUS_FORCELIST # Bad Gateway
assert 503 in STATUS_FORCELIST # Service Unavailable
assert 504 in STATUS_FORCELIST # Gateway Timeout
class TestCreateSSRFProxyMCPHTTPClient:
"""Test create_ssrf_proxy_mcp_http_client function."""
@patch("core.mcp.utils.dify_config")
def test_create_client_with_all_url_proxy(self, mock_config):
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client(
headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
)
assert isinstance(client, httpx.Client)
assert client.headers["Authorization"] == "Bearer token"
assert client.timeout.connect == 30.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_with_http_https_proxies(self, mock_config):
"""Test client creation with separate HTTP/HTTPS proxies."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_without_proxy(self, mock_config):
"""Test client creation without proxy configuration."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
headers = {"X-Custom-Header": "value"}
timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0)
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
assert isinstance(client, httpx.Client)
assert client.headers["X-Custom-Header"] == "value"
assert client.timeout.connect == 5.0
assert client.timeout.read == 10.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_default_params(self, mock_config):
"""Test client creation with default parameters."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
# httpx.Client adds default headers, so we just check it's a Headers object
assert isinstance(client.headers, httpx.Headers)
# When no timeout is provided, httpx uses its default timeout
assert client.timeout is not None
# Clean up
client.close()
class TestSSRFProxySSEConnect:
"""Test ssrf_proxy_sse_connect function."""
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with pre-configured client."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call with provided client
result = ssrf_proxy_sse_connect(
"http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"}
)
# Verify client creation was not called
mock_create_client.assert_not_called()
# Verify connect_sse was called correctly
mock_connect_sse.assert_called_once_with(
mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.dify_config")
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
"""Test SSE connection without pre-configured client."""
# Setup config
mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call without client
result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})
# Verify client was created
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["headers"] == {"X-Custom": "value"}
timeout = call_args[1]["timeout"]
# httpx.Timeout object has these attributes
assert isinstance(timeout, httpx.Timeout)
assert timeout.connect == 10.0
assert timeout.read == 60.0
assert timeout.write == 30.0
# Verify connect_sse was called
mock_connect_sse.assert_called_once_with(
mock_client,
"GET", # Default method
"http://example.com/sse",
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with custom timeout."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
custom_timeout = httpx.Timeout(timeout=60.0, read=120.0)
# Call with custom timeout
result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout)
# Verify client was created with custom timeout
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["timeout"] == custom_timeout
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
"""Test SSE connection cleans up client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse")
# Verify client was cleaned up
mock_client.close.assert_called_once()
@patch("core.mcp.utils.connect_sse")
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
"""Test SSE connection doesn't clean up provided client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client)
# Verify client was NOT cleaned up (because it was provided)
mock_client.close.assert_not_called()
class TestCreateMCPErrorResponse:
"""Test create_mcp_error_response function."""
def test_create_error_response_basic(self):
"""Test creating basic error response."""
generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")
# Generator should yield bytes
assert isinstance(generator, Generator)
# Get the response
response_bytes = next(generator)
assert isinstance(response_bytes, bytes)
# Parse the response
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["jsonrpc"] == "2.0"
assert response_json["id"] == "req-123"
assert response_json["error"]["code"] == -32600
assert response_json["error"]["message"] == "Invalid Request"
assert response_json["error"]["data"] is None
# Generator should be exhausted
with pytest.raises(StopIteration):
next(generator)
def test_create_error_response_with_data(self):
"""Test creating error response with additional data."""
error_data = {"field": "username", "reason": "required"}
generator = create_mcp_error_response(
request_id=456, # Numeric ID
code=-32602,
message="Invalid params",
data=error_data,
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["id"] == 456
assert response_json["error"]["code"] == -32602
assert response_json["error"]["message"] == "Invalid params"
assert response_json["error"]["data"] == error_data
def test_create_error_response_without_request_id(self):
"""Test creating error response without request ID."""
generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
# Should default to ID 1
assert response_json["id"] == 1
assert response_json["error"]["code"] == -32700
assert response_json["error"]["message"] == "Parse error"
def test_create_error_response_with_complex_data(self):
"""Test creating error response with complex error data."""
complex_data = {
"errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
"timestamp": "2024-01-01T00:00:00Z",
}
generator = create_mcp_error_response(
request_id="complex-req", code=-32602, message="Validation failed", data=complex_data
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["error"]["data"] == complex_data
assert len(response_json["error"]["data"]["errors"]) == 2
def test_create_error_response_encoding(self):
"""Test error response with non-ASCII characters."""
generator = create_mcp_error_response(
request_id="unicode-req",
code=-32603,
message="内部错误", # Chinese characters
data={"details": "エラー詳細"}, # Japanese characters
)
response_bytes = next(generator)
# Should be valid UTF-8
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["error"]["message"] == "内部错误"
assert response_json["error"]["data"]["details"] == "エラー詳細"
def test_create_error_response_yields_once(self):
"""Test that error response generator yields exactly once."""
generator = create_mcp_error_response(request_id="test", code=-32600, message="Test")
# First yield should work
first_yield = next(generator)
assert isinstance(first_yield, bytes)
# Second yield should raise StopIteration
with pytest.raises(StopIteration):
next(generator)
# Subsequent calls should also raise
with pytest.raises(StopIteration):
next(generator)