From e2fd3f29830031e102b1ade163d7cb21c75dda3e Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 16 Sep 2025 16:18:41 +0800 Subject: [PATCH] feat: add unit test --- .../console/workspace/tool_providers.py | 15 +- api/core/mcp/auth_client.py | 84 +-- api/core/mcp/auth_client_comparison.md | 1 + api/core/mcp/client/sse_client.py | 4 +- api/core/mcp/types.py | 2 +- api/core/tools/mcp_tool/tool.py | 1 - api/tests/unit_tests/core/mcp/__init__.py | 0 .../unit_tests/core/mcp/auth/__init__.py | 0 .../core/mcp/auth/test_auth_flow.py | 710 ++++++++++++++++++ .../unit_tests/core/mcp/test_auth_client.py | 523 +++++++++++++ .../core/mcp/test_auth_client_inheritance.py | 0 .../unit_tests/core/mcp/test_entities.py | 239 ++++++ api/tests/unit_tests/core/mcp/test_error.py | 205 +++++ .../unit_tests/core/mcp/test_mcp_client.py | 382 ++++++++++ api/tests/unit_tests/core/mcp/test_types.py | 492 ++++++++++++ api/tests/unit_tests/core/mcp/test_utils.py | 355 +++++++++ 16 files changed, 2945 insertions(+), 68 deletions(-) create mode 100644 api/core/mcp/auth_client_comparison.md create mode 100644 api/tests/unit_tests/core/mcp/__init__.py create mode 100644 api/tests/unit_tests/core/mcp/auth/__init__.py create mode 100644 api/tests/unit_tests/core/mcp/auth/test_auth_flow.py create mode 100644 api/tests/unit_tests/core/mcp/test_auth_client.py create mode 100644 api/tests/unit_tests/core/mcp/test_auth_client_inheritance.py create mode 100644 api/tests/unit_tests/core/mcp/test_entities.py create mode 100644 api/tests/unit_tests/core/mcp/test_error.py create mode 100644 api/tests/unit_tests/core/mcp/test_mcp_client.py create mode 100644 api/tests/unit_tests/core/mcp/test_types.py create mode 100644 api/tests/unit_tests/core/mcp/test_utils.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 094370f1cc..96bc288a77 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,8 +18,8 @@ from controllers.console.wraps import ( setup_required, ) from core.mcp.auth.auth_flow import auth, handle_callback -from core.mcp.auth_client import MCPClientWithAuthRetry -from core.mcp.error import MCPError +from core.mcp.error import MCPAuthError, MCPError +from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler @@ -974,17 +974,11 @@ class ToolMCPAuthApi(Resource): headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" try: # Use MCPClientWithAuthRetry to handle authentication automatically - with MCPClientWithAuthRetry( + with MCPClient( server_url=server_url, headers=headers, timeout=provider_entity.timeout, sse_read_timeout=provider_entity.sse_read_timeout, - provider_entity=provider_entity - if not provider_entity.headers - else None, # Only use auth retry if no custom headers - auth_callback=auth if not provider_entity.headers else None, - authorization_code=args.get("authorization_code"), - mcp_service=service, ): service.update_provider_credentials( provider=db_provider, @@ -993,7 +987,8 @@ class ToolMCPAuthApi(Resource): ) session.commit() return {"result": "success"} - + except MCPAuthError as e: + return auth(provider_entity, service, args.get("authorization_code")) except MCPError as e: service.clear_provider_credentials(provider=db_provider) session.commit() diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index ede2518dde..dac0f3f3ec 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -1,13 +1,12 @@ """ MCP Client with Authentication Retry Support -This module provides a wrapper around MCPClient that automatically handles +This module provides an enhanced MCPClient that automatically handles authentication failures and retries operations after refreshing tokens. """ import logging from collections.abc import Callable -from types import TracebackType from typing import TYPE_CHECKING, Any, Optional from core.entities.mcp_provider import MCPProviderEntity @@ -21,12 +20,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class MCPClientWithAuthRetry: +class MCPClientWithAuthRetry(MCPClient): """ - A wrapper around MCPClient that provides automatic authentication retry. + An enhanced MCPClient that provides automatic authentication retry. - This class intercepts MCPAuthError exceptions and attempts to refresh - authentication before retrying the failed operation. + This class extends MCPClient and intercepts MCPAuthError exceptions + to refresh authentication before retrying failed operations. """ def __init__( @@ -53,27 +52,17 @@ class MCPClientWithAuthRetry: provider_entity: Provider entity for authentication auth_callback: Authentication callback function authorization_code: Optional authorization code for initial auth + by_server_id: Whether to look up provider by server ID + mcp_service: MCP service instance """ - self.server_url = server_url - self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout + super().__init__(server_url, headers, timeout, sse_read_timeout) + self.provider_entity = provider_entity self.auth_callback = auth_callback self.authorization_code = authorization_code - self._has_retried = False - self._client: MCPClient | None = None self.by_server_id = by_server_id self.mcp_service = mcp_service - - def _create_client(self) -> MCPClient: - """Create a new MCPClient instance with current headers.""" - return MCPClient( - server_url=self.server_url, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - ) + self._has_retried = False def _handle_auth_error(self, error: MCPAuthError) -> None: """ @@ -134,38 +123,35 @@ class MCPClientWithAuthRetry: return func(*args, **kwargs) except MCPAuthError as e: self._handle_auth_error(e) - # Recreate client with new headers - if self._client: - self._client.cleanup() - self._client = self._create_client() - self._client.__enter__() + + # Re-initialize the connection with new headers + if self._initialized: + # Clean up existing connection + self._exit_stack.close() + self._session = None + self._initialized = False + + # Re-initialize with new headers + self._initialize() + self._initialized = True + return func(*args, **kwargs) finally: # Reset retry flag after operation completes self._has_retried = False def __enter__(self): - """Enter the context manager.""" - self._client = self._create_client() + """Enter the context manager with retry support.""" - # Try to initialize with retry - def initialize(): - if self._client is None: - raise ValueError("Client not created") - self._client.__enter__() + def initialize_with_retry(): + super(MCPClientWithAuthRetry, self).__enter__() return self - return self._execute_with_retry(initialize) - - def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None): - """Exit the context manager.""" - if self._client: - self._client.__exit__(exc_type, exc_value, traceback) - self._client = None + return self._execute_with_retry(initialize_with_retry) def list_tools(self) -> list[Tool]: """ - List available tools from the MCP server. + List available tools from the MCP server with auth retry. Returns: List of available tools @@ -173,13 +159,11 @@ class MCPClientWithAuthRetry: Raises: MCPAuthError: If authentication fails after retries """ - if not self._client: - raise ValueError("Client not initialized. Use within a context manager.") - return self._execute_with_retry(self._client.list_tools) + return self._execute_with_retry(super().list_tools) def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: """ - Invoke a tool on the MCP server. + Invoke a tool on the MCP server with auth retry. Args: tool_name: Name of the tool to invoke @@ -191,12 +175,4 @@ class MCPClientWithAuthRetry: Raises: MCPAuthError: If authentication fails after retries """ - if not self._client: - raise ValueError("Client not initialized. Use within a context manager.") - return self._execute_with_retry(self._client.invoke_tool, tool_name, tool_args) - - def cleanup(self): - """Clean up resources.""" - if self._client: - self._client.cleanup() - self._client = None + return self._execute_with_retry(super().invoke_tool, tool_name, tool_args) diff --git a/api/core/mcp/auth_client_comparison.md b/api/core/mcp/auth_client_comparison.md new file mode 100644 index 0000000000..0519ecba6e --- /dev/null +++ b/api/core/mcp/auth_client_comparison.md @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 6db22a09e0..2c7276e585 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -46,7 +46,7 @@ class SSETransport: url: str, headers: dict[str, Any] | None = None, timeout: float = 5.0, - sse_read_timeout: float = 5 * 60, + sse_read_timeout: float = 1 * 60, ): """Initialize the SSE transport. @@ -255,7 +255,7 @@ def sse_client( url: str, headers: dict[str, Any] | None = None, timeout: float = 5.0, - sse_read_timeout: float = 5 * 60, + sse_read_timeout: float = 1 * 60, ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]: """ Client transport for SSE. diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index fa0360df03..55c989ca1e 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -30,7 +30,7 @@ DEFAULT_NEGOTIATED_VERSION = "2025-03-26" ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -RequestId = Annotated[int, Field(strict=True)] | str +RequestId = Annotated[int | str, Field(union_mode="left_to_right")] AnyFunction: TypeAlias = Callable[..., Any] diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 27274c859b..6398cc56a9 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -162,7 +162,6 @@ class MCPTool(Tool): sse_read_timeout=self.sse_read_timeout, provider_entity=provider_entity, auth_callback=auth if mcp_service else None, - by_server_id=True, mcp_service=mcp_service, ) as mcp_client: return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) diff --git a/api/tests/unit_tests/core/mcp/__init__.py b/api/tests/unit_tests/core/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/mcp/auth/__init__.py b/api/tests/unit_tests/core/mcp/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py new file mode 100644 index 0000000000..cce77aa018 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -0,0 +1,710 @@ +"""Unit tests for MCP OAuth authentication flow.""" + +from unittest.mock import Mock, patch + +import pytest + +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.auth.auth_flow import ( + OAUTH_STATE_EXPIRY_SECONDS, + OAUTH_STATE_REDIS_KEY_PREFIX, + OAuthCallbackState, + _create_secure_redis_state, + _retrieve_redis_state, + auth, + check_support_resource_discovery, + discover_oauth_metadata, + exchange_authorization, + generate_pkce_challenge, + handle_callback, + refresh_authorization, + register_client, + start_authorization, +) +from core.mcp.types import ( + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthTokens, +) + + +class TestPKCEGeneration: + """Test PKCE challenge generation.""" + + def test_generate_pkce_challenge(self): + """Test PKCE challenge and verifier generation.""" + code_verifier, code_challenge = generate_pkce_challenge() + + # Verify format - should be URL-safe base64 without padding + assert "=" not in code_verifier + assert "+" not in code_verifier + assert "/" not in code_verifier + assert "=" not in code_challenge + assert "+" not in code_challenge + assert "/" not in code_challenge + + # Verify length + assert len(code_verifier) > 40 # Should be around 54 characters + assert len(code_challenge) > 40 # Should be around 43 characters + + def test_generate_pkce_challenge_uniqueness(self): + """Test that PKCE generation produces unique values.""" + results = set() + for _ in range(10): + code_verifier, code_challenge = generate_pkce_challenge() + results.add((code_verifier, code_challenge)) + + # All should be unique + assert len(results) == 10 + + +class TestRedisStateManagement: + """Test Redis state management functions.""" + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_create_secure_redis_state(self, mock_redis): + """Test creating secure Redis state.""" + state_data = OAuthCallbackState( + provider_id="test-provider", + tenant_id="test-tenant", + server_url="https://example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="test-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + + state_key = _create_secure_redis_state(state_data) + + # Verify state key format + assert len(state_key) > 20 # Should be a secure random token + + # Verify Redis call + mock_redis.setex.assert_called_once() + call_args = mock_redis.setex.call_args + assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX) + assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS + assert state_data.model_dump_json() in call_args[0][2] + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_retrieve_redis_state_success(self, mock_redis): + """Test retrieving state from Redis.""" + state_data = OAuthCallbackState( + provider_id="test-provider", + tenant_id="test-tenant", + server_url="https://example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="test-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + mock_redis.get.return_value = state_data.model_dump_json() + + result = _retrieve_redis_state("test-state-key") + + # Verify result + assert result.provider_id == "test-provider" + assert result.tenant_id == "test-tenant" + assert result.server_url == "https://example.com" + + # Verify Redis calls + mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key") + mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key") + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_retrieve_redis_state_not_found(self, mock_redis): + """Test retrieving non-existent state from Redis.""" + mock_redis.get.return_value = None + + with pytest.raises(ValueError) as exc_info: + _retrieve_redis_state("nonexistent-key") + + assert "State parameter has expired or does not exist" in str(exc_info.value) + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_retrieve_redis_state_invalid_json(self, mock_redis): + """Test retrieving invalid JSON state from Redis.""" + mock_redis.get.return_value = '{"invalid": json}' + + with pytest.raises(ValueError) as exc_info: + _retrieve_redis_state("test-key") + + assert "Invalid state parameter" in str(exc_info.value) + # State should still be deleted + mock_redis.delete.assert_called_once() + + +class TestOAuthDiscovery: + """Test OAuth discovery functions.""" + + @patch("httpx.get") + def test_check_support_resource_discovery_success(self, mock_get): + """Test successful resource discovery check.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]} + mock_get.return_value = mock_response + + supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint") + + assert supported is True + assert auth_url == "https://auth.example.com" + mock_get.assert_called_once_with( + "https://api.example.com/.well-known/oauth-protected-resource/endpoint", + headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, + ) + + @patch("httpx.get") + def test_check_support_resource_discovery_not_supported(self, mock_get): + """Test resource discovery not supported.""" + mock_response = Mock() + mock_response.status_code = 404 + mock_get.return_value = mock_response + + supported, auth_url = check_support_resource_discovery("https://api.example.com") + + assert supported is False + assert auth_url == "" + + @patch("httpx.get") + def test_check_support_resource_discovery_with_query_fragment(self, mock_get): + """Test resource discovery with query and fragment.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]} + mock_get.return_value = mock_response + + supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment") + + assert supported is True + assert auth_url == "https://auth.example.com" + mock_get.assert_called_once_with( + "https://api.example.com/.well-known/oauth-protected-resource/path?query=1#fragment", + headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"}, + ) + + @patch("httpx.get") + def test_discover_oauth_metadata_with_resource_discovery(self, mock_get): + """Test OAuth metadata discovery with resource discovery support.""" + with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: + mock_check.return_value = (True, "https://auth.example.com/.well-known/oauth-authorization-server") + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.is_success = True + mock_response.json.return_value = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "response_types_supported": ["code"], + } + mock_get.return_value = mock_response + + metadata = discover_oauth_metadata("https://api.example.com") + + assert metadata is not None + assert metadata.authorization_endpoint == "https://auth.example.com/authorize" + assert metadata.token_endpoint == "https://auth.example.com/token" + mock_get.assert_called_once_with( + "https://auth.example.com/.well-known/oauth-authorization-server", + headers={"MCP-Protocol-Version": "2025-03-26"}, + ) + + @patch("httpx.get") + def test_discover_oauth_metadata_without_resource_discovery(self, mock_get): + """Test OAuth metadata discovery without resource discovery.""" + with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: + mock_check.return_value = (False, "") + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.is_success = True + mock_response.json.return_value = { + "authorization_endpoint": "https://api.example.com/oauth/authorize", + "token_endpoint": "https://api.example.com/oauth/token", + "response_types_supported": ["code"], + } + mock_get.return_value = mock_response + + metadata = discover_oauth_metadata("https://api.example.com") + + assert metadata is not None + assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize" + mock_get.assert_called_once_with( + "https://api.example.com/.well-known/oauth-authorization-server", + headers={"MCP-Protocol-Version": "2025-03-26"}, + ) + + @patch("httpx.get") + def test_discover_oauth_metadata_not_found(self, mock_get): + """Test OAuth metadata discovery when not found.""" + with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check: + mock_check.return_value = (False, "") + + mock_response = Mock() + mock_response.status_code = 404 + mock_get.return_value = mock_response + + metadata = discover_oauth_metadata("https://api.example.com") + + assert metadata is None + + +class TestAuthorizationFlow: + """Test authorization flow functions.""" + + @patch("core.mcp.auth.auth_flow._create_secure_redis_state") + def test_start_authorization_with_metadata(self, mock_create_state): + """Test starting authorization with metadata.""" + mock_create_state.return_value = "secure-state-key" + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + code_challenge_methods_supported=["S256"], + ) + client_info = OAuthClientInformation(client_id="test-client-id") + + auth_url, code_verifier = start_authorization( + "https://api.example.com", + metadata, + client_info, + "https://redirect.example.com", + "provider-id", + "tenant-id", + ) + + # Verify URL format + assert auth_url.startswith("https://auth.example.com/authorize?") + assert "response_type=code" in auth_url + assert "client_id=test-client-id" in auth_url + assert "code_challenge=" in auth_url + assert "code_challenge_method=S256" in auth_url + assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url + assert "state=secure-state-key" in auth_url + + # Verify code verifier + assert len(code_verifier) > 40 + + # Verify state was stored + mock_create_state.assert_called_once() + state_data = mock_create_state.call_args[0][0] + assert state_data.provider_id == "provider-id" + assert state_data.tenant_id == "tenant-id" + assert state_data.code_verifier == code_verifier + + def test_start_authorization_without_metadata(self): + """Test starting authorization without metadata.""" + with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state: + mock_create_state.return_value = "secure-state-key" + + client_info = OAuthClientInformation(client_id="test-client-id") + + auth_url, code_verifier = start_authorization( + "https://api.example.com", + None, + client_info, + "https://redirect.example.com", + "provider-id", + "tenant-id", + ) + + # Should use default authorization endpoint + assert auth_url.startswith("https://api.example.com/authorize?") + + def test_start_authorization_invalid_metadata(self): + """Test starting authorization with invalid metadata.""" + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["token"], # No "code" support + code_challenge_methods_supported=["plain"], # No "S256" support + ) + client_info = OAuthClientInformation(client_id="test-client-id") + + with pytest.raises(ValueError) as exc_info: + start_authorization( + "https://api.example.com", + metadata, + client_info, + "https://redirect.example.com", + "provider-id", + "tenant-id", + ) + + assert "does not support response type code" in str(exc_info.value) + + @patch("httpx.post") + def test_exchange_authorization_success(self, mock_post): + """Test successful authorization code exchange.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "access_token": "new-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh-token", + } + mock_post.return_value = mock_response + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret") + + tokens = exchange_authorization( + "https://api.example.com", + metadata, + client_info, + "auth-code-123", + "code-verifier-xyz", + "https://redirect.example.com", + ) + + assert tokens.access_token == "new-access-token" + assert tokens.token_type == "Bearer" + assert tokens.expires_in == 3600 + assert tokens.refresh_token == "new-refresh-token" + + # Verify request + mock_post.assert_called_once_with( + "https://auth.example.com/token", + data={ + "grant_type": "authorization_code", + "client_id": "test-client-id", + "client_secret": "test-secret", + "code": "auth-code-123", + "code_verifier": "code-verifier-xyz", + "redirect_uri": "https://redirect.example.com", + }, + ) + + @patch("httpx.post") + def test_exchange_authorization_failure(self, mock_post): + """Test failed authorization code exchange.""" + mock_response = Mock() + mock_response.is_success = False + mock_response.status_code = 400 + mock_post.return_value = mock_response + + client_info = OAuthClientInformation(client_id="test-client-id") + + with pytest.raises(ValueError) as exc_info: + exchange_authorization( + "https://api.example.com", + None, + client_info, + "invalid-code", + "code-verifier", + "https://redirect.example.com", + ) + + assert "Token exchange failed: HTTP 400" in str(exc_info.value) + + @patch("httpx.post") + def test_refresh_authorization_success(self, mock_post): + """Test successful token refresh.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "access_token": "refreshed-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh-token", + } + mock_post.return_value = mock_response + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["refresh_token"], + ) + client_info = OAuthClientInformation(client_id="test-client-id") + + tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token") + + assert tokens.access_token == "refreshed-access-token" + assert tokens.refresh_token == "new-refresh-token" + + # Verify request + mock_post.assert_called_once_with( + "https://auth.example.com/token", + data={ + "grant_type": "refresh_token", + "client_id": "test-client-id", + "refresh_token": "old-refresh-token", + }, + ) + + @patch("httpx.post") + def test_register_client_success(self, mock_post): + """Test successful client registration.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "client_id": "new-client-id", + "client_secret": "new-client-secret", + "client_name": "Dify", + "redirect_uris": ["https://redirect.example.com"], + } + mock_post.return_value = mock_response + + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint="https://auth.example.com/register", + response_types_supported=["code"], + ) + client_metadata = OAuthClientMetadata( + client_name="Dify", + redirect_uris=["https://redirect.example.com"], + grant_types=["authorization_code"], + response_types=["code"], + ) + + client_info = register_client("https://api.example.com", metadata, client_metadata) + + assert isinstance(client_info, OAuthClientInformationFull) + assert client_info.client_id == "new-client-id" + assert client_info.client_secret == "new-client-secret" + + # Verify request + mock_post.assert_called_once_with( + "https://auth.example.com/register", + json=client_metadata.model_dump(), + headers={"Content-Type": "application/json"}, + ) + + def test_register_client_no_endpoint(self): + """Test client registration when no endpoint available.""" + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint=None, + response_types_supported=["code"], + ) + client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"]) + + with pytest.raises(ValueError) as exc_info: + register_client("https://api.example.com", metadata, client_metadata) + + assert "does not support dynamic client registration" in str(exc_info.value) + + +class TestCallbackHandling: + """Test OAuth callback handling.""" + + @patch("core.mcp.auth.auth_flow._retrieve_redis_state") + @patch("core.mcp.auth.auth_flow.exchange_authorization") + def test_handle_callback_success(self, mock_exchange, mock_retrieve_state): + """Test successful callback handling.""" + # Setup state + state_data = OAuthCallbackState( + provider_id="test-provider", + tenant_id="test-tenant", + server_url="https://api.example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="test-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + mock_retrieve_state.return_value = state_data + + # Setup token exchange + tokens = OAuthTokens( + access_token="new-token", + token_type="Bearer", + expires_in=3600, + ) + mock_exchange.return_value = tokens + + # Setup service + mock_service = Mock() + + result = handle_callback("state-key", "auth-code", mock_service) + + assert result == state_data + + # Verify calls + mock_retrieve_state.assert_called_once_with("state-key") + mock_exchange.assert_called_once_with( + "https://api.example.com", + None, + state_data.client_information, + "auth-code", + "test-verifier", + "https://redirect.example.com", + ) + mock_service.save_oauth_data.assert_called_once_with( + "test-provider", "test-tenant", tokens.model_dump(), "tokens" + ) + + +class TestAuthOrchestration: + """Test the main auth orchestration function.""" + + @pytest.fixture + def mock_provider(self): + """Create a mock provider entity.""" + provider = Mock(spec=MCPProviderEntity) + provider.id = "provider-id" + provider.tenant_id = "tenant-id" + provider.decrypt_server_url.return_value = "https://api.example.com" + provider.client_metadata = OAuthClientMetadata( + client_name="Dify", + redirect_uris=["https://redirect.example.com"], + ) + provider.redirect_url = "https://redirect.example.com" + provider.retrieve_client_information.return_value = None + provider.retrieve_tokens.return_value = None + return provider + + @pytest.fixture + def mock_service(self): + """Create a mock MCP service.""" + return Mock() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + @patch("core.mcp.auth.auth_flow.register_client") + @patch("core.mcp.auth.auth_flow.start_authorization") + def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service): + """Test auth flow for new client registration.""" + # Setup + mock_discover.return_value = None + mock_register.return_value = OAuthClientInformationFull( + client_id="new-client-id", + client_name="Dify", + redirect_uris=["https://redirect.example.com"], + ) + mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier") + + result = auth(mock_provider, mock_service) + + assert result == {"authorization_url": "https://auth.example.com/authorize?..."} + + # Verify calls + mock_register.assert_called_once() + mock_service.save_oauth_data.assert_any_call( + "provider-id", + "tenant-id", + {"client_information": mock_register.return_value.model_dump()}, + "client_info", + ) + mock_service.save_oauth_data.assert_any_call( + "provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier" + ) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + @patch("core.mcp.auth.auth_flow._retrieve_redis_state") + @patch("core.mcp.auth.auth_flow.exchange_authorization") + def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service): + """Test auth flow for exchanging authorization code.""" + # Setup metadata discovery + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + # Setup existing client + mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") + + # Setup state retrieval + state_data = OAuthCallbackState( + provider_id="provider-id", + tenant_id="tenant-id", + server_url="https://api.example.com", + metadata=None, + client_information=OAuthClientInformation(client_id="existing-client"), + code_verifier="test-verifier", + redirect_uri="https://redirect.example.com", + ) + mock_retrieve_state.return_value = state_data + + # Setup token exchange + tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600) + mock_exchange.return_value = tokens + + result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key") + + assert result == {"result": "success"} + + # Verify token save + mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens") + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service): + """Test auth flow fails when exchanging code without state.""" + # Setup metadata discovery + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") + + with pytest.raises(ValueError) as exc_info: + auth(mock_provider, mock_service, authorization_code="auth-code") + + assert "State parameter is required" in str(exc_info.value) + + @patch("core.mcp.auth.auth_flow.refresh_authorization") + def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service): + """Test auth flow for refreshing tokens.""" + # Setup existing client and tokens + mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client") + mock_provider.retrieve_tokens.return_value = OAuthTokens( + access_token="old-token", + token_type="Bearer", + expires_in=0, + refresh_token="refresh-token", + ) + + # Setup refresh + new_tokens = OAuthTokens( + access_token="refreshed-token", + token_type="Bearer", + expires_in=3600, + refresh_token="new-refresh-token", + ) + mock_refresh.return_value = new_tokens + + with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover: + mock_discover.return_value = None + + result = auth(mock_provider, mock_service) + + assert result == {"result": "success"} + + # Verify refresh was called + mock_refresh.assert_called_once() + mock_service.save_oauth_data.assert_called_with( + "provider-id", "tenant-id", new_tokens.model_dump(), "tokens" + ) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service): + """Test auth fails when no client info exists but code is provided.""" + # Setup metadata discovery + mock_discover.return_value = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + mock_provider.retrieve_client_information.return_value = None + + with pytest.raises(ValueError) as exc_info: + auth(mock_provider, mock_service, authorization_code="auth-code") + + assert "Existing OAuth client information is required" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/mcp/test_auth_client.py b/api/tests/unit_tests/core/mcp/test_auth_client.py new file mode 100644 index 0000000000..58fa85f4f9 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_auth_client.py @@ -0,0 +1,523 @@ +"""Unit tests for MCP auth client with retry logic.""" + +from types import TracebackType +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPAuthError +from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations + + +class TestMCPClientWithAuthRetry: + """Test suite for MCPClientWithAuthRetry.""" + + @pytest.fixture + def mock_provider_entity(self): + """Create a mock provider entity.""" + provider = Mock(spec=MCPProviderEntity) + provider.id = "test-provider-id" + provider.tenant_id = "test-tenant-id" + provider.retrieve_tokens.return_value = Mock( + access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + return provider + + @pytest.fixture + def mock_mcp_service(self): + """Create a mock MCP service.""" + service = Mock() + service.get_provider_entity.return_value = Mock( + retrieve_tokens=lambda: Mock( + access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + ) + return service + + @pytest.fixture + def auth_callback(self): + """Create a mock auth callback.""" + return Mock() + + def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback): + """Test client initialization.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer test"}, + timeout=30.0, + sse_read_timeout=60.0, + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + authorization_code="test-auth-code", + by_server_id=True, + mcp_service=mock_mcp_service, + ) + + assert client.server_url == "http://test.example.com" + assert client.headers == {"Authorization": "Bearer test"} + assert client.timeout == 30.0 + assert client.sse_read_timeout == 60.0 + assert client.provider_entity == mock_provider_entity + assert client.auth_callback == auth_callback + assert client.authorization_code == "test-auth-code" + assert client.by_server_id is True + assert client.mcp_service == mock_mcp_service + assert client._has_retried is False + # In inheritance design, we don't have _client attribute + assert hasattr(client, "_session") # Inherited from MCPClient + + def test_inheritance_structure(self): + """Test that MCPClientWithAuthRetry properly inherits from MCPClient.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer test"}, + ) + + # Verify inheritance + assert isinstance(client, MCPClient) + + # Verify inherited attributes are accessible + assert hasattr(client, "server_url") + assert hasattr(client, "headers") + assert hasattr(client, "_session") + assert hasattr(client, "_exit_stack") + assert hasattr(client, "_initialized") + + def test_handle_auth_error_no_retry_components(self): + """Test auth error handling when retry components are missing.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + client._handle_auth_error(error) + + assert exc_info.value == error + + def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback): + """Test auth error handling when already retried.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + client._has_retried = True + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + client._handle_auth_error(error) + + assert exc_info.value == error + auth_callback.assert_not_called() + + def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback): + """Test successful auth refresh on error.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + authorization_code="test-code", + by_server_id=True, + mcp_service=mock_mcp_service, + ) + + # Configure mocks + new_provider = Mock(spec=MCPProviderEntity) + new_provider.id = "test-provider-id" + new_provider.tenant_id = "test-tenant-id" + new_provider.retrieve_tokens.return_value = Mock( + access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + mock_mcp_service.get_provider_entity.return_value = new_provider + + error = MCPAuthError("Auth failed") + client._handle_auth_error(error) + + # Verify auth flow + auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code") + mock_mcp_service.get_provider_entity.assert_called_once_with( + "test-provider-id", "test-tenant-id", by_server_id=True + ) + assert client.headers["Authorization"] == "Bearer new-token" + assert client.authorization_code is None # Should be cleared after use + assert client._has_retried is True + + def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback): + """Test auth refresh failure.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + + auth_callback.side_effect = Exception("Auth callback failed") + + error = MCPAuthError("Original auth failed") + with pytest.raises(MCPAuthError) as exc_info: + client._handle_auth_error(error) + + assert "Authentication retry failed" in str(exc_info.value) + + def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback): + """Test auth refresh when no token is received.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + + # Configure mock to return no token + new_provider = Mock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = None + mock_mcp_service.get_provider_entity.return_value = new_provider + + error = MCPAuthError("Auth failed") + with pytest.raises(MCPAuthError) as exc_info: + client._handle_auth_error(error) + + assert "no token received" in str(exc_info.value) + + def test_execute_with_retry_success(self): + """Test successful execution without retry.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + + mock_func = Mock(return_value="success") + result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") + + assert result == "success" + mock_func.assert_called_once_with("arg1", kwarg1="value1") + assert client._has_retried is False + + @patch("core.mcp.auth_client.MCPClient") + def test_execute_with_retry_auth_error_then_success( + self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback + ): + """Test execution with auth error followed by successful retry.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + + # Configure mock clients (old and new) + mock_client_old = MagicMock() + mock_client_new = MagicMock() + client._client = mock_client_old + + # Make _create_client return the new client on retry + mock_mcp_client_class.return_value = mock_client_new + + # Configure new provider with token + new_provider = Mock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = Mock( + access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + mock_mcp_service.get_provider_entity.return_value = new_provider + + # Mock function that fails first, then succeeds + mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"]) + + result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") + + assert result == "success" + assert mock_func.call_count == 2 + mock_func.assert_called_with("arg1", kwarg1="value1") + auth_callback.assert_called_once() + mock_client_old.cleanup.assert_called_once() + mock_client_new.__enter__.assert_called_once() + assert client._has_retried is False # Reset after completion + + def test_execute_with_retry_non_auth_error(self): + """Test execution with non-auth error (no retry).""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + + mock_func = Mock(side_effect=ValueError("Some other error")) + + with pytest.raises(ValueError) as exc_info: + client._execute_with_retry(mock_func) + + assert str(exc_info.value) == "Some other error" + mock_func.assert_called_once() + + @patch("core.mcp.auth_client.MCPClient") + def test_context_manager_enter(self, mock_mcp_client_class): + """Test context manager enter.""" + mock_client_instance = MagicMock() + mock_client_instance.__enter__.return_value = mock_client_instance + mock_mcp_client_class.return_value = mock_client_instance + + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + + result = client.__enter__() + + assert result == client + assert client._client == mock_client_instance + mock_client_instance.__enter__.assert_called_once() + + @patch("core.mcp.auth_client.MCPClient") + def test_context_manager_enter_with_auth_error( + self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback + ): + """Test context manager enter with auth error and retry.""" + mock_client_instance = MagicMock() + mock_mcp_client_class.return_value = mock_client_instance + + # Configure new provider with token + new_provider = Mock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = Mock( + access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + mock_mcp_service.get_provider_entity.return_value = new_provider + + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + + # First call to client.__enter__ raises auth error, second succeeds + call_count = 0 + + def enter_side_effect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise MCPAuthError("Auth failed") + return mock_client_instance + + mock_client_instance.__enter__.side_effect = enter_side_effect + + result = client.__enter__() + + assert result == client + assert mock_client_instance.__enter__.call_count == 3 + auth_callback.assert_called_once() + + def test_context_manager_exit(self): + """Test context manager exit.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + mock_client = MagicMock() + client._client = mock_client + + exc_type: type[BaseException] | None = None + exc_val: BaseException | None = None + exc_tb: TracebackType | None = None + client.__exit__(exc_type, exc_val, exc_tb) + + mock_client.__exit__.assert_called_once_with(None, None, None) + assert client._client is None + + def test_list_tools_not_initialized(self): + """Test list_tools when client not initialized.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + + with pytest.raises(ValueError) as exc_info: + client.list_tools() + + assert "Client not initialized" in str(exc_info.value) + + def test_list_tools_success(self): + """Test successful list_tools call.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + mock_client = Mock() + client._client = mock_client + + expected_tools = [ + Tool( + name="test-tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(title="Test Tool"), + ) + ] + mock_client.list_tools.return_value = expected_tools + + result = client.list_tools() + + assert result == expected_tools + mock_client.list_tools.assert_called_once() + + @patch("core.mcp.auth_client.MCPClient") + def test_list_tools_with_auth_retry( + self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback + ): + """Test list_tools with auth retry.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + + # Configure mock clients (old and new) + mock_client_old = MagicMock() + mock_client_new = MagicMock() + client._client = mock_client_old + + # Make _create_client return the new client on retry + mock_mcp_client_class.return_value = mock_client_new + + # Configure new provider with token + new_provider = Mock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = Mock( + access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + mock_mcp_service.get_provider_entity.return_value = new_provider + + expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})] + # First call raises auth error + mock_client_old.list_tools.side_effect = MCPAuthError("Auth failed") + mock_client_new.list_tools.return_value = expected_tools + + # We need to mock the behavior where after client is replaced, + # the new method should be called. But since the method reference + # is already bound to the old client, we need to work around this. + # Let's patch the _execute_with_retry to handle this properly. + + with patch.object(client, "_execute_with_retry") as mock_execute: + # Simulate the retry behavior + call_count = [0] + + def execute_side_effect(func, *args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + # First call - simulate auth error and retry + try: + func(*args, **kwargs) + except MCPAuthError: + # Simulate the retry logic + client._handle_auth_error(MCPAuthError("Auth failed")) + if client._client: + client._client.cleanup() + client._client = mock_client_new + client._client.__enter__() + # Now return the result from the new client + return mock_client_new.list_tools(*args, **kwargs) + return func(*args, **kwargs) + + mock_execute.side_effect = execute_side_effect + result = client.list_tools() + + assert result == expected_tools + mock_client_old.list_tools.assert_called_once() + mock_client_new.list_tools.assert_called_once() + auth_callback.assert_called_once() + mock_client_old.cleanup.assert_called_once() + mock_client_new.__enter__.assert_called_once() + + def test_invoke_tool_not_initialized(self): + """Test invoke_tool when client not initialized.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + + with pytest.raises(ValueError) as exc_info: + client.invoke_tool("test-tool", {"arg": "value"}) + + assert "Client not initialized" in str(exc_info.value) + + def test_invoke_tool_success(self): + """Test successful invoke_tool call.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + mock_client = Mock() + client._client = mock_client + + expected_result = CallToolResult( + content=[TextContent(type="text", text="Tool executed successfully")], isError=False + ) + mock_client.invoke_tool.return_value = expected_result + + result = client.invoke_tool("test-tool", {"arg": "value"}) + + assert result == expected_result + mock_client.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"}) + + @patch("core.mcp.auth_client.MCPClient") + def test_invoke_tool_with_auth_retry( + self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback + ): + """Test invoke_tool with auth retry.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + provider_entity=mock_provider_entity, + auth_callback=auth_callback, + mcp_service=mock_mcp_service, + ) + + # Configure mock clients (old and new) + mock_client_old = MagicMock() + mock_client_new = MagicMock() + client._client = mock_client_old + + # Make _create_client return the new client on retry + mock_mcp_client_class.return_value = mock_client_new + + # Configure new provider with token + new_provider = Mock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = Mock( + access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None + ) + mock_mcp_service.get_provider_entity.return_value = new_provider + + expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False) + # First call raises auth error + mock_client_old.invoke_tool.side_effect = MCPAuthError("Auth failed") + mock_client_new.invoke_tool.return_value = expected_result + + # We need to mock the behavior where after client is replaced, + # the new method should be called. Similar to list_tools test. + + with patch.object(client, "_execute_with_retry") as mock_execute: + # Simulate the retry behavior + call_count = [0] + + def execute_side_effect(func, *args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + # First call - simulate auth error and retry + try: + func(*args, **kwargs) + except MCPAuthError: + # Simulate the retry logic + client._handle_auth_error(MCPAuthError("Auth failed")) + if client._client: + client._client.cleanup() + client._client = mock_client_new + client._client.__enter__() + # Now return the result from the new client + return mock_client_new.invoke_tool(*args, **kwargs) + return func(*args, **kwargs) + + mock_execute.side_effect = execute_side_effect + result = client.invoke_tool("test-tool", {"arg": "value"}) + + assert result == expected_result + mock_client_old.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"}) + mock_client_new.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"}) + auth_callback.assert_called_once() + mock_client_old.cleanup.assert_called_once() + mock_client_new.__enter__.assert_called_once() + + def test_cleanup(self): + """Test cleanup method.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + mock_client = Mock() + client._client = mock_client + + client.cleanup() + + mock_client.cleanup.assert_called_once() + assert client._client is None + + def test_cleanup_no_client(self): + """Test cleanup when no client exists.""" + client = MCPClientWithAuthRetry(server_url="http://test.example.com") + + # Should not raise + client.cleanup() + + assert client._client is None diff --git a/api/tests/unit_tests/core/mcp/test_auth_client_inheritance.py b/api/tests/unit_tests/core/mcp/test_auth_client_inheritance.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/mcp/test_entities.py b/api/tests/unit_tests/core/mcp/test_entities.py new file mode 100644 index 0000000000..3fede55916 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_entities.py @@ -0,0 +1,239 @@ +"""Unit tests for MCP entities module.""" + +from unittest.mock import Mock + +from core.mcp.entities import ( + SUPPORTED_PROTOCOL_VERSIONS, + LifespanContextT, + RequestContext, + SessionT, +) +from core.mcp.session.base_session import BaseSession +from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams + + +class TestProtocolVersions: + """Test protocol version constants.""" + + def test_supported_protocol_versions(self): + """Test supported protocol versions list.""" + assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list) + assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3 + assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS + assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS + assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS + + def test_latest_protocol_version_is_supported(self): + """Test that latest protocol version is in supported versions.""" + assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS + + +class TestRequestContext: + """Test RequestContext dataclass.""" + + def test_request_context_creation(self): + """Test creating a RequestContext instance.""" + mock_session = Mock(spec=BaseSession) + mock_lifespan = {"key": "value"} + mock_meta = RequestParams.Meta(progressToken="test-token") + + context = RequestContext( + request_id="test-request-123", + meta=mock_meta, + session=mock_session, + lifespan_context=mock_lifespan, + ) + + assert context.request_id == "test-request-123" + assert context.meta == mock_meta + assert context.session == mock_session + assert context.lifespan_context == mock_lifespan + + def test_request_context_with_none_meta(self): + """Test creating RequestContext with None meta.""" + mock_session = Mock(spec=BaseSession) + + context = RequestContext( + request_id=42, # Can be int or string + meta=None, + session=mock_session, + lifespan_context=None, + ) + + assert context.request_id == 42 + assert context.meta is None + assert context.session == mock_session + assert context.lifespan_context is None + + def test_request_context_attributes(self): + """Test RequestContext attributes are accessible.""" + mock_session = Mock(spec=BaseSession) + + context = RequestContext( + request_id="test-123", + meta=None, + session=mock_session, + lifespan_context=None, + ) + + # Verify attributes are accessible + assert hasattr(context, "request_id") + assert hasattr(context, "meta") + assert hasattr(context, "session") + assert hasattr(context, "lifespan_context") + + # Verify values + assert context.request_id == "test-123" + assert context.meta is None + assert context.session == mock_session + assert context.lifespan_context is None + + def test_request_context_generic_typing(self): + """Test RequestContext with different generic types.""" + # Create a mock session with specific type + mock_session = Mock(spec=BaseSession) + + # Create context with string lifespan context + context_str = RequestContext[BaseSession, str]( + request_id="test-1", + meta=None, + session=mock_session, + lifespan_context="string-context", + ) + assert isinstance(context_str.lifespan_context, str) + + # Create context with dict lifespan context + context_dict = RequestContext[BaseSession, dict]( + request_id="test-2", + meta=None, + session=mock_session, + lifespan_context={"key": "value"}, + ) + assert isinstance(context_dict.lifespan_context, dict) + + # Create context with custom object lifespan context + class CustomLifespan: + def __init__(self, data): + self.data = data + + custom_lifespan = CustomLifespan("test-data") + context_custom = RequestContext[BaseSession, CustomLifespan]( + request_id="test-3", + meta=None, + session=mock_session, + lifespan_context=custom_lifespan, + ) + assert isinstance(context_custom.lifespan_context, CustomLifespan) + assert context_custom.lifespan_context.data == "test-data" + + def test_request_context_with_progress_meta(self): + """Test RequestContext with progress metadata.""" + mock_session = Mock(spec=BaseSession) + progress_meta = RequestParams.Meta(progressToken="progress-123") + + context = RequestContext( + request_id="req-456", + meta=progress_meta, + session=mock_session, + lifespan_context=None, + ) + + assert context.meta is not None + assert context.meta.progressToken == "progress-123" + + def test_request_context_equality(self): + """Test RequestContext equality comparison.""" + mock_session1 = Mock(spec=BaseSession) + mock_session2 = Mock(spec=BaseSession) + + context1 = RequestContext( + request_id="test-123", + meta=None, + session=mock_session1, + lifespan_context="context", + ) + + context2 = RequestContext( + request_id="test-123", + meta=None, + session=mock_session1, + lifespan_context="context", + ) + + context3 = RequestContext( + request_id="test-456", + meta=None, + session=mock_session1, + lifespan_context="context", + ) + + # Same values should be equal + assert context1 == context2 + + # Different request_id should not be equal + assert context1 != context3 + + # Different session should not be equal + context4 = RequestContext( + request_id="test-123", + meta=None, + session=mock_session2, + lifespan_context="context", + ) + assert context1 != context4 + + def test_request_context_repr(self): + """Test RequestContext string representation.""" + mock_session = Mock(spec=BaseSession) + mock_session.__repr__ = Mock(return_value="") + + context = RequestContext( + request_id="test-123", + meta=None, + session=mock_session, + lifespan_context={"data": "test"}, + ) + + repr_str = repr(context) + assert "RequestContext" in repr_str + assert "test-123" in repr_str + assert "MockSession" in repr_str + + +class TestTypeVariables: + """Test type variables defined in the module.""" + + def test_session_type_var(self): + """Test SessionT type variable.""" + + # Create a custom session class + class CustomSession(BaseSession): + pass + + # Use in generic context + def process_session(session: SessionT) -> SessionT: + return session + + mock_session = Mock(spec=CustomSession) + result = process_session(mock_session) + assert result == mock_session + + def test_lifespan_context_type_var(self): + """Test LifespanContextT type variable.""" + + # Use in generic context + def process_lifespan(context: LifespanContextT) -> LifespanContextT: + return context + + # Test with different types + str_context = "string-context" + assert process_lifespan(str_context) == str_context + + dict_context = {"key": "value"} + assert process_lifespan(dict_context) == dict_context + + class CustomContext: + pass + + custom_context = CustomContext() + assert process_lifespan(custom_context) == custom_context diff --git a/api/tests/unit_tests/core/mcp/test_error.py b/api/tests/unit_tests/core/mcp/test_error.py new file mode 100644 index 0000000000..3a95fae673 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_error.py @@ -0,0 +1,205 @@ +"""Unit tests for MCP error classes.""" + +import pytest + +from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError + + +class TestMCPError: + """Test MCPError base exception class.""" + + def test_mcp_error_creation(self): + """Test creating MCPError instance.""" + error = MCPError("Test error message") + assert str(error) == "Test error message" + assert isinstance(error, Exception) + + def test_mcp_error_inheritance(self): + """Test MCPError inherits from Exception.""" + error = MCPError() + assert isinstance(error, Exception) + assert type(error).__name__ == "MCPError" + + def test_mcp_error_with_empty_message(self): + """Test MCPError with empty message.""" + error = MCPError() + assert str(error) == "" + + def test_mcp_error_raise(self): + """Test raising MCPError.""" + with pytest.raises(MCPError) as exc_info: + raise MCPError("Something went wrong") + + assert str(exc_info.value) == "Something went wrong" + + +class TestMCPConnectionError: + """Test MCPConnectionError exception class.""" + + def test_mcp_connection_error_creation(self): + """Test creating MCPConnectionError instance.""" + error = MCPConnectionError("Connection failed") + assert str(error) == "Connection failed" + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_connection_error_inheritance(self): + """Test MCPConnectionError inheritance chain.""" + error = MCPConnectionError() + assert isinstance(error, MCPConnectionError) + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_connection_error_raise(self): + """Test raising MCPConnectionError.""" + with pytest.raises(MCPConnectionError) as exc_info: + raise MCPConnectionError("Unable to connect to server") + + assert str(exc_info.value) == "Unable to connect to server" + + def test_mcp_connection_error_catch_as_mcp_error(self): + """Test catching MCPConnectionError as MCPError.""" + with pytest.raises(MCPError) as exc_info: + raise MCPConnectionError("Connection issue") + + assert isinstance(exc_info.value, MCPConnectionError) + assert str(exc_info.value) == "Connection issue" + + +class TestMCPAuthError: + """Test MCPAuthError exception class.""" + + def test_mcp_auth_error_creation(self): + """Test creating MCPAuthError instance.""" + error = MCPAuthError("Authentication failed") + assert str(error) == "Authentication failed" + assert isinstance(error, MCPConnectionError) + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_auth_error_inheritance(self): + """Test MCPAuthError inheritance chain.""" + error = MCPAuthError() + assert isinstance(error, MCPAuthError) + assert isinstance(error, MCPConnectionError) + assert isinstance(error, MCPError) + assert isinstance(error, Exception) + + def test_mcp_auth_error_raise(self): + """Test raising MCPAuthError.""" + with pytest.raises(MCPAuthError) as exc_info: + raise MCPAuthError("Invalid credentials") + + assert str(exc_info.value) == "Invalid credentials" + + def test_mcp_auth_error_catch_hierarchy(self): + """Test catching MCPAuthError at different levels.""" + # Catch as MCPAuthError + with pytest.raises(MCPAuthError) as exc_info: + raise MCPAuthError("Auth specific error") + assert str(exc_info.value) == "Auth specific error" + + # Catch as MCPConnectionError + with pytest.raises(MCPConnectionError) as exc_info: + raise MCPAuthError("Auth connection error") + assert isinstance(exc_info.value, MCPAuthError) + assert str(exc_info.value) == "Auth connection error" + + # Catch as MCPError + with pytest.raises(MCPError) as exc_info: + raise MCPAuthError("Auth base error") + assert isinstance(exc_info.value, MCPAuthError) + assert str(exc_info.value) == "Auth base error" + + +class TestErrorHierarchy: + """Test the complete error hierarchy.""" + + def test_exception_hierarchy(self): + """Test the complete exception hierarchy.""" + # Create instances + base_error = MCPError("base") + connection_error = MCPConnectionError("connection") + auth_error = MCPAuthError("auth") + + # Test type relationships + assert not isinstance(base_error, MCPConnectionError) + assert not isinstance(base_error, MCPAuthError) + + assert isinstance(connection_error, MCPError) + assert not isinstance(connection_error, MCPAuthError) + + assert isinstance(auth_error, MCPError) + assert isinstance(auth_error, MCPConnectionError) + + def test_error_handling_patterns(self): + """Test common error handling patterns.""" + + def raise_auth_error(): + raise MCPAuthError("401 Unauthorized") + + def raise_connection_error(): + raise MCPConnectionError("Connection timeout") + + def raise_base_error(): + raise MCPError("Generic error") + + # Pattern 1: Catch specific errors first + errors_caught = [] + + for error_func in [raise_auth_error, raise_connection_error, raise_base_error]: + try: + error_func() + except MCPAuthError: + errors_caught.append("auth") + except MCPConnectionError: + errors_caught.append("connection") + except MCPError: + errors_caught.append("base") + + assert errors_caught == ["auth", "connection", "base"] + + # Pattern 2: Catch all as base error + for error_func in [raise_auth_error, raise_connection_error, raise_base_error]: + with pytest.raises(MCPError) as exc_info: + error_func() + assert isinstance(exc_info.value, MCPError) + + def test_error_with_cause(self): + """Test errors with cause (chained exceptions).""" + original_error = ValueError("Original error") + + def raise_chained_error(): + try: + raise original_error + except ValueError as e: + raise MCPConnectionError("Connection failed") from e + + with pytest.raises(MCPConnectionError) as exc_info: + raise_chained_error() + + assert str(exc_info.value) == "Connection failed" + assert exc_info.value.__cause__ == original_error + + def test_error_comparison(self): + """Test error instance comparison.""" + error1 = MCPError("Test message") + error2 = MCPError("Test message") + error3 = MCPError("Different message") + + # Errors are not equal even with same message (different instances) + assert error1 != error2 + assert error1 != error3 + + # But they have the same type + assert type(error1) == type(error2) == type(error3) + + def test_error_representation(self): + """Test error string representation.""" + base_error = MCPError("Base error message") + connection_error = MCPConnectionError("Connection error message") + auth_error = MCPAuthError("Auth error message") + + assert repr(base_error) == "MCPError('Base error message')" + assert repr(connection_error) == "MCPConnectionError('Connection error message')" + assert repr(auth_error) == "MCPAuthError('Auth error message')" diff --git a/api/tests/unit_tests/core/mcp/test_mcp_client.py b/api/tests/unit_tests/core/mcp/test_mcp_client.py new file mode 100644 index 0000000000..c0420d3371 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_mcp_client.py @@ -0,0 +1,382 @@ +"""Unit tests for MCP client.""" + +from contextlib import ExitStack +from types import TracebackType +from unittest.mock import Mock, patch + +import pytest + +from core.mcp.error import MCPConnectionError +from core.mcp.mcp_client import MCPClient +from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations + + +class TestMCPClient: + """Test suite for MCPClient.""" + + def test_init(self): + """Test client initialization.""" + client = MCPClient( + server_url="http://test.example.com/mcp", + headers={"Authorization": "Bearer test"}, + timeout=30.0, + sse_read_timeout=60.0, + ) + + assert client.server_url == "http://test.example.com/mcp" + assert client.headers == {"Authorization": "Bearer test"} + assert client.timeout == 30.0 + assert client.sse_read_timeout == 60.0 + assert client._session is None + assert isinstance(client._exit_stack, ExitStack) + assert client._initialized is False + + def test_init_defaults(self): + """Test client initialization with defaults.""" + client = MCPClient(server_url="http://test.example.com") + + assert client.server_url == "http://test.example.com" + assert client.headers == {} + assert client.timeout is None + assert client.sse_read_timeout is None + + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client): + """Test initialization with MCP URL.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/mcp") + client._initialize() + + # Verify streamable client was called + mock_streamable_client.assert_called_once_with( + url="http://test.example.com/mcp", + headers={}, + timeout=None, + sse_read_timeout=None, + ) + + # Verify session was created + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + assert client._session == mock_session + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client): + """Test initialization with SSE URL.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/sse") + client._initialize() + + # Verify SSE client was called + mock_sse_client.assert_called_once_with( + url="http://test.example.com/sse", + headers={}, + timeout=None, + sse_read_timeout=None, + ) + + # Verify session was created + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + assert client._session == mock_session + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_with_unknown_method_fallback_to_sse( + self, mock_client_session, mock_streamable_client, mock_sse_client + ): + """Test initialization with unknown method falls back to SSE.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/unknown") + client._initialize() + + # Verify SSE client was tried + mock_sse_client.assert_called_once() + mock_streamable_client.assert_not_called() + + # Verify session was created + assert client._session == mock_session + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client): + """Test initialization falls back from SSE to MCP on connection error.""" + # Setup SSE to fail + mock_sse_client.side_effect = MCPConnectionError("SSE connection failed") + + # Setup MCP to succeed + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com/unknown") + client._initialize() + + # Verify both were tried + mock_sse_client.assert_called_once() + mock_streamable_client.assert_called_once() + + # Verify session was created with MCP + assert client._session == mock_session + + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_connect_server_mcp(self, mock_client_session, mock_streamable_client): + """Test connect_server with MCP method.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com") + client.connect_server(mock_streamable_client, "mcp") + + # Verify correct streams were passed + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + + @patch("core.mcp.mcp_client.sse_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_connect_server_sse(self, mock_client_session, mock_sse_client): + """Test connect_server with SSE method.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient(server_url="http://test.example.com") + client.connect_server(mock_sse_client, "sse") + + # Verify correct streams were passed + mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_session.initialize.assert_called_once() + + def test_context_manager_enter(self): + """Test context manager enter.""" + client = MCPClient(server_url="http://test.example.com") + + with patch.object(client, "_initialize") as mock_initialize: + result = client.__enter__() + + assert result == client + assert client._initialized is True + mock_initialize.assert_called_once() + + def test_context_manager_exit(self): + """Test context manager exit.""" + client = MCPClient(server_url="http://test.example.com") + + with patch.object(client, "cleanup") as mock_cleanup: + exc_type: type[BaseException] | None = None + exc_val: BaseException | None = None + exc_tb: TracebackType | None = None + client.__exit__(exc_type, exc_val, exc_tb) + + mock_cleanup.assert_called_once() + + def test_list_tools_not_initialized(self): + """Test list_tools when session not initialized.""" + client = MCPClient(server_url="http://test.example.com") + + with pytest.raises(ValueError) as exc_info: + client.list_tools() + + assert "Session not initialized" in str(exc_info.value) + + def test_list_tools_success(self): + """Test successful list_tools call.""" + client = MCPClient(server_url="http://test.example.com") + + # Setup mock session + mock_session = Mock() + expected_tools = [ + Tool( + name="test-tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(title="Test Tool"), + ) + ] + mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools) + client._session = mock_session + + result = client.list_tools() + + assert result == expected_tools + mock_session.list_tools.assert_called_once() + + def test_invoke_tool_not_initialized(self): + """Test invoke_tool when session not initialized.""" + client = MCPClient(server_url="http://test.example.com") + + with pytest.raises(ValueError) as exc_info: + client.invoke_tool("test-tool", {"arg": "value"}) + + assert "Session not initialized" in str(exc_info.value) + + def test_invoke_tool_success(self): + """Test successful invoke_tool call.""" + client = MCPClient(server_url="http://test.example.com") + + # Setup mock session + mock_session = Mock() + expected_result = CallToolResult( + content=[TextContent(type="text", text="Tool executed successfully")], + isError=False, + ) + mock_session.call_tool.return_value = expected_result + client._session = mock_session + + result = client.invoke_tool("test-tool", {"arg": "value"}) + + assert result == expected_result + mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"}) + + def test_cleanup(self): + """Test cleanup method.""" + client = MCPClient(server_url="http://test.example.com") + mock_exit_stack = Mock(spec=ExitStack) + client._exit_stack = mock_exit_stack + client._session = Mock() + client._initialized = True + + client.cleanup() + + mock_exit_stack.close.assert_called_once() + assert client._session is None + assert client._initialized is False + + def test_cleanup_with_error(self): + """Test cleanup method with error.""" + client = MCPClient(server_url="http://test.example.com") + mock_exit_stack = Mock(spec=ExitStack) + mock_exit_stack.close.side_effect = Exception("Cleanup error") + client._exit_stack = mock_exit_stack + client._session = Mock() + client._initialized = True + + with pytest.raises(ValueError) as exc_info: + client.cleanup() + + assert "Error during cleanup: Cleanup error" in str(exc_info.value) + assert client._session is None + assert client._initialized is False + + @patch("core.mcp.mcp_client.streamablehttp_client") + @patch("core.mcp.mcp_client.ClientSession") + def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client): + """Test full context manager flow.""" + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})] + mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools) + + with MCPClient(server_url="http://test.example.com/mcp") as client: + assert client._initialized is True + assert client._session == mock_session + + # Test tool operations + tools = client.list_tools() + assert tools == expected_tools + + # After exit, should be cleaned up + assert client._initialized is False + assert client._session is None + + def test_headers_passed_to_clients(self): + """Test that headers are properly passed to underlying clients.""" + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + } + + with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client: + with patch("core.mcp.mcp_client.ClientSession") as mock_client_session: + # Setup mocks + mock_read_stream = Mock() + mock_write_stream = Mock() + mock_client_context = Mock() + mock_streamable_client.return_value.__enter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_client_context, + ) + + mock_session = Mock() + mock_client_session.return_value.__enter__.return_value = mock_session + + client = MCPClient( + server_url="http://test.example.com/mcp", + headers=custom_headers, + timeout=30.0, + sse_read_timeout=60.0, + ) + client._initialize() + + # Verify headers were passed + mock_streamable_client.assert_called_once_with( + url="http://test.example.com/mcp", + headers=custom_headers, + timeout=30.0, + sse_read_timeout=60.0, + ) diff --git a/api/tests/unit_tests/core/mcp/test_types.py b/api/tests/unit_tests/core/mcp/test_types.py new file mode 100644 index 0000000000..6d8130bd13 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_types.py @@ -0,0 +1,492 @@ +"""Unit tests for MCP types module.""" + +import pytest +from pydantic import ValidationError + +from core.mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + PARSE_ERROR, + SERVER_LATEST_PROTOCOL_VERSION, + Annotations, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientCapabilities, + CompleteRequest, + CompleteRequestParams, + CompleteResult, + Completion, + CompletionArgument, + CompletionContext, + ErrorData, + ImageContent, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + OAuthClientInformation, + OAuthClientMetadata, + OAuthMetadata, + OAuthTokens, + PingRequest, + ProgressNotification, + ProgressNotificationParams, + PromptReference, + RequestParams, + ResourceTemplateReference, + Result, + ServerCapabilities, + TextContent, + Tool, + ToolAnnotations, +) + + +class TestConstants: + """Test module constants.""" + + def test_protocol_versions(self): + """Test protocol version constants.""" + assert LATEST_PROTOCOL_VERSION == "2025-03-26" + assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05" + + def test_error_codes(self): + """Test JSON-RPC error code constants.""" + assert PARSE_ERROR == -32700 + assert INVALID_REQUEST == -32600 + assert METHOD_NOT_FOUND == -32601 + assert INVALID_PARAMS == -32602 + assert INTERNAL_ERROR == -32603 + + +class TestRequestParams: + """Test RequestParams and related classes.""" + + def test_request_params_basic(self): + """Test basic RequestParams creation.""" + params = RequestParams() + assert params.meta is None + + def test_request_params_with_meta(self): + """Test RequestParams with meta.""" + meta = RequestParams.Meta(progressToken="test-token") + params = RequestParams(_meta=meta) + assert params.meta is not None + assert params.meta.progressToken == "test-token" + + def test_request_params_meta_extra_fields(self): + """Test RequestParams.Meta allows extra fields.""" + meta = RequestParams.Meta(progressToken="token", customField="value") + assert meta.progressToken == "token" + assert meta.customField == "value" # type: ignore + + def test_request_params_serialization(self): + """Test RequestParams serialization with _meta alias.""" + meta = RequestParams.Meta(progressToken="test") + params = RequestParams(_meta=meta) + + # Model dump should use the alias + dumped = params.model_dump(by_alias=True) + assert "_meta" in dumped + assert dumped["_meta"] is not None + assert dumped["_meta"]["progressToken"] == "test" + + +class TestJSONRPCMessages: + """Test JSON-RPC message types.""" + + def test_jsonrpc_request(self): + """Test JSONRPCRequest creation and validation.""" + request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"}) + + assert request.jsonrpc == "2.0" + assert request.id == "test-123" + assert request.method == "test_method" + assert request.params == {"key": "value"} + + def test_jsonrpc_request_numeric_id(self): + """Test JSONRPCRequest with numeric ID.""" + request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None) + assert request.id == 123 + + def test_jsonrpc_notification(self): + """Test JSONRPCNotification creation.""" + notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"}) + + assert notification.jsonrpc == "2.0" + assert notification.method == "notification_method" + assert not hasattr(notification, "id") # Notifications don't have ID + + def test_jsonrpc_response(self): + """Test JSONRPCResponse creation.""" + response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True}) + + assert response.jsonrpc == "2.0" + assert response.id == "req-123" + assert response.result == {"success": True} + + def test_jsonrpc_error(self): + """Test JSONRPCError creation.""" + error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"}) + + error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data) + + assert error.jsonrpc == "2.0" + assert error.id == "req-123" + assert error.error.code == INVALID_PARAMS + assert error.error.message == "Invalid parameters" + assert error.error.data == {"field": "missing"} + + def test_jsonrpc_message_parsing(self): + """Test JSONRPCMessage parsing different message types.""" + # Parse request + request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}' + msg = JSONRPCMessage.model_validate_json(request_json) + assert isinstance(msg.root, JSONRPCRequest) + + # Parse response + response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}' + msg = JSONRPCMessage.model_validate_json(response_json) + assert isinstance(msg.root, JSONRPCResponse) + + # Parse error + error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}' + msg = JSONRPCMessage.model_validate_json(error_json) + assert isinstance(msg.root, JSONRPCError) + + +class TestCapabilities: + """Test capability classes.""" + + def test_client_capabilities(self): + """Test ClientCapabilities creation.""" + caps = ClientCapabilities( + experimental={"feature": {"enabled": True}}, + sampling={"model_config": {"extra": "allow"}}, + roots={"listChanged": True}, + ) + + assert caps.experimental == {"feature": {"enabled": True}} + assert caps.sampling is not None + assert caps.roots.listChanged is True # type: ignore + + def test_server_capabilities(self): + """Test ServerCapabilities creation.""" + caps = ServerCapabilities( + tools={"listChanged": True}, + resources={"subscribe": True, "listChanged": False}, + prompts={"listChanged": True}, + logging={}, + completions={}, + ) + + assert caps.tools.listChanged is True # type: ignore + assert caps.resources.subscribe is True # type: ignore + assert caps.resources.listChanged is False # type: ignore + + +class TestInitialization: + """Test initialization request/response types.""" + + def test_initialize_request(self): + """Test InitializeRequest creation.""" + client_info = Implementation(name="test-client", version="1.0.0") + capabilities = ClientCapabilities() + + params = InitializeRequestParams( + protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info + ) + + request = InitializeRequest(params=params) + + assert request.method == "initialize" + assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION + assert request.params.clientInfo.name == "test-client" + + def test_initialize_result(self): + """Test InitializeResult creation.""" + server_info = Implementation(name="test-server", version="1.0.0") + capabilities = ServerCapabilities() + + result = InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + serverInfo=server_info, + instructions="Welcome to test server", + ) + + assert result.protocolVersion == LATEST_PROTOCOL_VERSION + assert result.serverInfo.name == "test-server" + assert result.instructions == "Welcome to test server" + + +class TestTools: + """Test tool-related types.""" + + def test_tool_creation(self): + """Test Tool creation with all fields.""" + tool = Tool( + name="test_tool", + title="Test Tool", + description="A tool for testing", + inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]}, + outputSchema={"type": "object", "properties": {"result": {"type": "string"}}}, + annotations=ToolAnnotations( + title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True + ), + ) + + assert tool.name == "test_tool" + assert tool.title == "Test Tool" + assert tool.description == "A tool for testing" + assert tool.inputSchema["properties"]["input"]["type"] == "string" + assert tool.annotations.idempotentHint is True + + def test_call_tool_request(self): + """Test CallToolRequest creation.""" + params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"}) + + request = CallToolRequest(params=params) + + assert request.method == "tools/call" + assert request.params.name == "test_tool" + assert request.params.arguments == {"input": "test value"} + + def test_call_tool_result(self): + """Test CallToolResult creation.""" + result = CallToolResult( + content=[TextContent(type="text", text="Tool executed successfully")], + structuredContent={"status": "success", "data": "test"}, + isError=False, + ) + + assert len(result.content) == 1 + assert result.content[0].text == "Tool executed successfully" # type: ignore + assert result.structuredContent == {"status": "success", "data": "test"} + assert result.isError is False + + def test_list_tools_request(self): + """Test ListToolsRequest creation.""" + request = ListToolsRequest() + assert request.method == "tools/list" + + def test_list_tools_result(self): + """Test ListToolsResult creation.""" + tool1 = Tool(name="tool1", inputSchema={}) + tool2 = Tool(name="tool2", inputSchema={}) + + result = ListToolsResult(tools=[tool1, tool2]) + + assert len(result.tools) == 2 + assert result.tools[0].name == "tool1" + assert result.tools[1].name == "tool2" + + +class TestContent: + """Test content types.""" + + def test_text_content(self): + """Test TextContent creation.""" + annotations = Annotations(audience=["user"], priority=0.8) + content = TextContent(type="text", text="Hello, world!", annotations=annotations) + + assert content.type == "text" + assert content.text == "Hello, world!" + assert content.annotations is not None + assert content.annotations.priority == 0.8 + + def test_image_content(self): + """Test ImageContent creation.""" + content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png") + + assert content.type == "image" + assert content.data == "base64encodeddata" + assert content.mimeType == "image/png" + + +class TestOAuth: + """Test OAuth-related types.""" + + def test_oauth_client_metadata(self): + """Test OAuthClientMetadata creation.""" + metadata = OAuthClientMetadata( + client_name="Test Client", + redirect_uris=["https://example.com/callback"], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + client_uri="https://example.com", + scope="read write", + ) + + assert metadata.client_name == "Test Client" + assert len(metadata.redirect_uris) == 1 + assert "authorization_code" in metadata.grant_types + + def test_oauth_client_information(self): + """Test OAuthClientInformation creation.""" + info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret") + + assert info.client_id == "test-client-id" + assert info.client_secret == "test-secret" + + def test_oauth_client_information_without_secret(self): + """Test OAuthClientInformation without secret.""" + info = OAuthClientInformation(client_id="public-client") + + assert info.client_id == "public-client" + assert info.client_secret is None + + def test_oauth_tokens(self): + """Test OAuthTokens creation.""" + tokens = OAuthTokens( + access_token="access-token-123", + token_type="Bearer", + expires_in=3600, + refresh_token="refresh-token-456", + scope="read write", + ) + + assert tokens.access_token == "access-token-123" + assert tokens.token_type == "Bearer" + assert tokens.expires_in == 3600 + assert tokens.refresh_token == "refresh-token-456" + assert tokens.scope == "read write" + + def test_oauth_metadata(self): + """Test OAuthMetadata creation.""" + metadata = OAuthMetadata( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint="https://auth.example.com/register", + response_types_supported=["code", "token"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["plain", "S256"], + ) + + assert metadata.authorization_endpoint == "https://auth.example.com/authorize" + assert "code" in metadata.response_types_supported + assert "S256" in metadata.code_challenge_methods_supported + + +class TestNotifications: + """Test notification types.""" + + def test_progress_notification(self): + """Test ProgressNotification creation.""" + params = ProgressNotificationParams( + progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%" + ) + + notification = ProgressNotification(params=params) + + assert notification.method == "notifications/progress" + assert notification.params.progressToken == "progress-123" + assert notification.params.progress == 50.0 + assert notification.params.total == 100.0 + assert notification.params.message == "Processing... 50%" + + def test_ping_request(self): + """Test PingRequest creation.""" + request = PingRequest() + assert request.method == "ping" + assert request.params is None + + +class TestCompletion: + """Test completion-related types.""" + + def test_completion_context(self): + """Test CompletionContext creation.""" + context = CompletionContext(arguments={"template_var": "value"}) + assert context.arguments == {"template_var": "value"} + + def test_resource_template_reference(self): + """Test ResourceTemplateReference creation.""" + ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}") + assert ref.type == "ref/resource" + assert ref.uri == "file:///path/to/{filename}" + + def test_prompt_reference(self): + """Test PromptReference creation.""" + ref = PromptReference(type="ref/prompt", name="test_prompt") + assert ref.type == "ref/prompt" + assert ref.name == "test_prompt" + + def test_complete_request(self): + """Test CompleteRequest creation.""" + ref = PromptReference(type="ref/prompt", name="test_prompt") + arg = CompletionArgument(name="arg1", value="val") + + params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"})) + + request = CompleteRequest(params=params) + + assert request.method == "completion/complete" + assert request.params.ref.name == "test_prompt" # type: ignore + assert request.params.argument.name == "arg1" + + def test_complete_result(self): + """Test CompleteResult creation.""" + completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True) + + result = CompleteResult(completion=completion) + + assert len(result.completion.values) == 3 + assert result.completion.total == 10 + assert result.completion.hasMore is True + + +class TestValidation: + """Test validation of various types.""" + + def test_invalid_jsonrpc_version(self): + """Test invalid JSON-RPC version validation.""" + with pytest.raises(ValidationError): + JSONRPCRequest( + jsonrpc="1.0", # Invalid version + id=1, + method="test", + ) + + def test_tool_annotations_validation(self): + """Test ToolAnnotations with invalid values.""" + # Valid annotations + annotations = ToolAnnotations( + title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False + ) + assert annotations.title == "Test" + + def test_extra_fields_allowed(self): + """Test that extra fields are allowed in models.""" + # Most models should allow extra fields + tool = Tool( + name="test", + inputSchema={}, + customField="allowed", # type: ignore + ) + assert tool.customField == "allowed" # type: ignore + + def test_result_meta_alias(self): + """Test Result model with _meta alias.""" + # Create with the field name (not alias) + result = Result(_meta={"key": "value"}) + + # Verify the field is set correctly + assert result.meta == {"key": "value"} + + # Dump with alias + dumped = result.model_dump(by_alias=True) + assert "_meta" in dumped + assert dumped["_meta"] == {"key": "value"} diff --git a/api/tests/unit_tests/core/mcp/test_utils.py b/api/tests/unit_tests/core/mcp/test_utils.py new file mode 100644 index 0000000000..ca41d5f4c1 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/test_utils.py @@ -0,0 +1,355 @@ +"""Unit tests for MCP utils module.""" + +import json +from collections.abc import Generator +from unittest.mock import MagicMock, Mock, patch + +import httpx +import httpx_sse +import pytest + +from core.mcp.utils import ( + STATUS_FORCELIST, + create_mcp_error_response, + create_ssrf_proxy_mcp_http_client, + ssrf_proxy_sse_connect, +) + + +class TestConstants: + """Test module constants.""" + + def test_status_forcelist(self): + """Test STATUS_FORCELIST contains expected HTTP status codes.""" + assert STATUS_FORCELIST == [429, 500, 502, 503, 504] + assert 429 in STATUS_FORCELIST # Too Many Requests + assert 500 in STATUS_FORCELIST # Internal Server Error + assert 502 in STATUS_FORCELIST # Bad Gateway + assert 503 in STATUS_FORCELIST # Service Unavailable + assert 504 in STATUS_FORCELIST # Gateway Timeout + + +class TestCreateSSRFProxyMCPHTTPClient: + """Test create_ssrf_proxy_mcp_http_client function.""" + + @patch("core.mcp.utils.dify_config") + def test_create_client_with_all_url_proxy(self, mock_config): + """Test client creation with SSRF_PROXY_ALL_URL configured.""" + mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080" + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True + + client = create_ssrf_proxy_mcp_http_client( + headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0) + ) + + assert isinstance(client, httpx.Client) + assert client.headers["Authorization"] == "Bearer token" + assert client.timeout.connect == 30.0 + assert client.follow_redirects is True + + # Clean up + client.close() + + @patch("core.mcp.utils.dify_config") + def test_create_client_with_http_https_proxies(self, mock_config): + """Test client creation with separate HTTP/HTTPS proxies.""" + mock_config.SSRF_PROXY_ALL_URL = None + mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080" + mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443" + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False + + client = create_ssrf_proxy_mcp_http_client() + + assert isinstance(client, httpx.Client) + assert client.follow_redirects is True + + # Clean up + client.close() + + @patch("core.mcp.utils.dify_config") + def test_create_client_without_proxy(self, mock_config): + """Test client creation without proxy configuration.""" + mock_config.SSRF_PROXY_ALL_URL = None + mock_config.SSRF_PROXY_HTTP_URL = None + mock_config.SSRF_PROXY_HTTPS_URL = None + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True + + headers = {"X-Custom-Header": "value"} + timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0) + + client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) + + assert isinstance(client, httpx.Client) + assert client.headers["X-Custom-Header"] == "value" + assert client.timeout.connect == 5.0 + assert client.timeout.read == 10.0 + assert client.follow_redirects is True + + # Clean up + client.close() + + @patch("core.mcp.utils.dify_config") + def test_create_client_default_params(self, mock_config): + """Test client creation with default parameters.""" + mock_config.SSRF_PROXY_ALL_URL = None + mock_config.SSRF_PROXY_HTTP_URL = None + mock_config.SSRF_PROXY_HTTPS_URL = None + mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True + + client = create_ssrf_proxy_mcp_http_client() + + assert isinstance(client, httpx.Client) + # httpx.Client adds default headers, so we just check it's a Headers object + assert isinstance(client.headers, httpx.Headers) + # When no timeout is provided, httpx uses its default timeout + assert client.timeout is not None + + # Clean up + client.close() + + +class TestSSRFProxySSEConnect: + """Test ssrf_proxy_sse_connect function.""" + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse): + """Test SSE connection with pre-configured client.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_event_source = Mock(spec=httpx_sse.EventSource) + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_event_source + mock_connect_sse.return_value = mock_context + + # Call with provided client + result = ssrf_proxy_sse_connect( + "http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"} + ) + + # Verify client creation was not called + mock_create_client.assert_not_called() + + # Verify connect_sse was called correctly + mock_connect_sse.assert_called_once_with( + mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"} + ) + + # Verify result + assert result == mock_context + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + @patch("core.mcp.utils.dify_config") + def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse): + """Test SSE connection without pre-configured client.""" + # Setup config + mock_config.SSRF_DEFAULT_TIME_OUT = 30.0 + mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0 + mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0 + mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0 + + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_create_client.return_value = mock_client + + mock_event_source = Mock(spec=httpx_sse.EventSource) + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_event_source + mock_connect_sse.return_value = mock_context + + # Call without client + result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"}) + + # Verify client was created + mock_create_client.assert_called_once() + call_args = mock_create_client.call_args + assert call_args[1]["headers"] == {"X-Custom": "value"} + + timeout = call_args[1]["timeout"] + # httpx.Timeout object has these attributes + assert isinstance(timeout, httpx.Timeout) + assert timeout.connect == 10.0 + assert timeout.read == 60.0 + assert timeout.write == 30.0 + + # Verify connect_sse was called + mock_connect_sse.assert_called_once_with( + mock_client, + "GET", # Default method + "http://example.com/sse", + ) + + # Verify result + assert result == mock_context + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse): + """Test SSE connection with custom timeout.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_create_client.return_value = mock_client + + mock_event_source = Mock(spec=httpx_sse.EventSource) + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_event_source + mock_connect_sse.return_value = mock_context + + custom_timeout = httpx.Timeout(timeout=60.0, read=120.0) + + # Call with custom timeout + result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout) + + # Verify client was created with custom timeout + mock_create_client.assert_called_once() + call_args = mock_create_client.call_args + assert call_args[1]["timeout"] == custom_timeout + + # Verify result + assert result == mock_context + + @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse): + """Test SSE connection cleans up client on error.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + mock_create_client.return_value = mock_client + + # Make connect_sse raise an exception + mock_connect_sse.side_effect = httpx.ConnectError("Connection failed") + + # Call should raise the exception + with pytest.raises(httpx.ConnectError): + ssrf_proxy_sse_connect("http://example.com/sse") + + # Verify client was cleaned up + mock_client.close.assert_called_once() + + @patch("core.mcp.utils.connect_sse") + def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse): + """Test SSE connection doesn't clean up provided client on error.""" + # Setup mocks + mock_client = Mock(spec=httpx.Client) + + # Make connect_sse raise an exception + mock_connect_sse.side_effect = httpx.ConnectError("Connection failed") + + # Call should raise the exception + with pytest.raises(httpx.ConnectError): + ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client) + + # Verify client was NOT cleaned up (because it was provided) + mock_client.close.assert_not_called() + + +class TestCreateMCPErrorResponse: + """Test create_mcp_error_response function.""" + + def test_create_error_response_basic(self): + """Test creating basic error response.""" + generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request") + + # Generator should yield bytes + assert isinstance(generator, Generator) + + # Get the response + response_bytes = next(generator) + assert isinstance(response_bytes, bytes) + + # Parse the response + response_str = response_bytes.decode("utf-8") + response_json = json.loads(response_str) + + assert response_json["jsonrpc"] == "2.0" + assert response_json["id"] == "req-123" + assert response_json["error"]["code"] == -32600 + assert response_json["error"]["message"] == "Invalid Request" + assert response_json["error"]["data"] is None + + # Generator should be exhausted + with pytest.raises(StopIteration): + next(generator) + + def test_create_error_response_with_data(self): + """Test creating error response with additional data.""" + error_data = {"field": "username", "reason": "required"} + + generator = create_mcp_error_response( + request_id=456, # Numeric ID + code=-32602, + message="Invalid params", + data=error_data, + ) + + response_bytes = next(generator) + response_json = json.loads(response_bytes.decode("utf-8")) + + assert response_json["id"] == 456 + assert response_json["error"]["code"] == -32602 + assert response_json["error"]["message"] == "Invalid params" + assert response_json["error"]["data"] == error_data + + def test_create_error_response_without_request_id(self): + """Test creating error response without request ID.""" + generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error") + + response_bytes = next(generator) + response_json = json.loads(response_bytes.decode("utf-8")) + + # Should default to ID 1 + assert response_json["id"] == 1 + assert response_json["error"]["code"] == -32700 + assert response_json["error"]["message"] == "Parse error" + + def test_create_error_response_with_complex_data(self): + """Test creating error response with complex error data.""" + complex_data = { + "errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}], + "timestamp": "2024-01-01T00:00:00Z", + } + + generator = create_mcp_error_response( + request_id="complex-req", code=-32602, message="Validation failed", data=complex_data + ) + + response_bytes = next(generator) + response_json = json.loads(response_bytes.decode("utf-8")) + + assert response_json["error"]["data"] == complex_data + assert len(response_json["error"]["data"]["errors"]) == 2 + + def test_create_error_response_encoding(self): + """Test error response with non-ASCII characters.""" + generator = create_mcp_error_response( + request_id="unicode-req", + code=-32603, + message="内部错误", # Chinese characters + data={"details": "エラー詳細"}, # Japanese characters + ) + + response_bytes = next(generator) + + # Should be valid UTF-8 + response_str = response_bytes.decode("utf-8") + response_json = json.loads(response_str) + + assert response_json["error"]["message"] == "内部错误" + assert response_json["error"]["data"]["details"] == "エラー詳細" + + def test_create_error_response_yields_once(self): + """Test that error response generator yields exactly once.""" + generator = create_mcp_error_response(request_id="test", code=-32600, message="Test") + + # First yield should work + first_yield = next(generator) + assert isinstance(first_yield, bytes) + + # Second yield should raise StopIteration + with pytest.raises(StopIteration): + next(generator) + + # Subsequent calls should also raise + with pytest.raises(StopIteration): + next(generator)