chore(mcp): fix pyright checks

This commit is contained in:
Novice 2025-09-16 17:25:09 +08:00
parent 5547247aa9
commit 685f199f91
1 changed files with 72 additions and 175 deletions

View File

@ -1,13 +1,14 @@
"""Unit tests for MCP auth client with retry logic."""
from types import TracebackType
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
@ -192,10 +193,7 @@ class TestMCPClientWithAuthRetry:
mock_func.assert_called_once_with("arg1", kwarg1="value1")
assert client._has_retried is False
@patch("core.mcp.auth_client.MCPClient")
def test_execute_with_retry_auth_error_then_success(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test execution with auth error followed by successful retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
@ -204,14 +202,6 @@ class TestMCPClientWithAuthRetry:
mcp_service=mock_mcp_service,
)
# Configure mock clients (old and new)
mock_client_old = MagicMock()
mock_client_new = MagicMock()
client._client = mock_client_old
# Make _create_client return the new client on retry
mock_mcp_client_class.return_value = mock_client_new
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
@ -222,15 +212,22 @@ class TestMCPClientWithAuthRetry:
# Mock function that fails first, then succeeds
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
# Mock the exit stack and session cleanup
with (
patch.object(client, "_exit_stack") as mock_exit_stack,
patch.object(client, "_session") as mock_session,
patch.object(client, "_initialize") as mock_initialize,
):
client._initialized = True
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
assert mock_func.call_count == 2
mock_func.assert_called_with("arg1", kwarg1="value1")
auth_callback.assert_called_once()
mock_client_old.cleanup.assert_called_once()
mock_client_new.__enter__.assert_called_once()
assert client._has_retried is False # Reset after completion
assert result == "success"
assert mock_func.call_count == 2
mock_func.assert_called_with("arg1", kwarg1="value1")
auth_callback.assert_called_once()
mock_exit_stack.close.assert_called_once()
mock_initialize.assert_called_once()
assert client._has_retried is False # Reset after completion
def test_execute_with_retry_non_auth_error(self):
"""Test execution with non-auth error (no retry)."""
@ -244,29 +241,19 @@ class TestMCPClientWithAuthRetry:
assert str(exc_info.value) == "Some other error"
mock_func.assert_called_once()
@patch("core.mcp.auth_client.MCPClient")
def test_context_manager_enter(self, mock_mcp_client_class):
def test_context_manager_enter(self):
"""Test context manager enter."""
mock_client_instance = MagicMock()
mock_client_instance.__enter__.return_value = mock_client_instance
mock_mcp_client_class.return_value = mock_client_instance
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
result = client.__enter__()
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._client == mock_client_instance
mock_client_instance.__enter__.assert_called_once()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
@patch("core.mcp.auth_client.MCPClient")
def test_context_manager_enter_with_auth_error(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test context manager enter with auth error and retry."""
mock_client_instance = MagicMock()
mock_mcp_client_class.return_value = mock_client_instance
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
@ -281,37 +268,27 @@ class TestMCPClientWithAuthRetry:
mcp_service=mock_mcp_service,
)
# First call to client.__enter__ raises auth error, second succeeds
call_count = 0
# Mock parent class __enter__ to raise auth error first, then succeed
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
def enter_side_effect():
nonlocal call_count
call_count += 1
if call_count == 1:
raise MCPAuthError("Auth failed")
return mock_client_instance
result = client.__enter__()
mock_client_instance.__enter__.side_effect = enter_side_effect
result = client.__enter__()
assert result == client
assert mock_client_instance.__enter__.call_count == 3
auth_callback.assert_called_once()
assert result == client
assert mock_parent_enter.call_count == 2
auth_callback.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = MagicMock()
client._client = mock_client
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_client.__exit__.assert_called_once_with(None, None, None)
assert client._client is None
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when client not initialized."""
@ -320,13 +297,11 @@ class TestMCPClientWithAuthRetry:
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Client not initialized" in str(exc_info.value)
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = Mock()
client._client = mock_client
expected_tools = [
Tool(
@ -336,17 +311,13 @@ class TestMCPClientWithAuthRetry:
annotations=ToolAnnotations(title="Test Tool"),
)
]
mock_client.list_tools.return_value = expected_tools
result = client.list_tools()
# Mock the parent class list_tools method
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
result = client.list_tools()
assert result == expected_tools
assert result == expected_tools
mock_client.list_tools.assert_called_once()
@patch("core.mcp.auth_client.MCPClient")
def test_list_tools_with_auth_retry(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test list_tools with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
@ -355,14 +326,6 @@ class TestMCPClientWithAuthRetry:
mcp_service=mock_mcp_service,
)
# Configure mock clients (old and new)
mock_client_old = MagicMock()
mock_client_new = MagicMock()
client._client = mock_client_old
# Make _create_client return the new client on retry
mock_mcp_client_class.return_value = mock_client_new
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
@ -371,45 +334,16 @@ class TestMCPClientWithAuthRetry:
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
# First call raises auth error
mock_client_old.list_tools.side_effect = MCPAuthError("Auth failed")
mock_client_new.list_tools.return_value = expected_tools
# We need to mock the behavior where after client is replaced,
# the new method should be called. But since the method reference
# is already bound to the old client, we need to work around this.
# Let's patch the _execute_with_retry to handle this properly.
# Mock parent class list_tools to raise auth error first, then succeed
with patch.object(MCPClient, "list_tools") as mock_list_tools:
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
with patch.object(client, "_execute_with_retry") as mock_execute:
# Simulate the retry behavior
call_count = [0]
def execute_side_effect(func, *args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
# First call - simulate auth error and retry
try:
func(*args, **kwargs)
except MCPAuthError:
# Simulate the retry logic
client._handle_auth_error(MCPAuthError("Auth failed"))
if client._client:
client._client.cleanup()
client._client = mock_client_new
client._client.__enter__()
# Now return the result from the new client
return mock_client_new.list_tools(*args, **kwargs)
return func(*args, **kwargs)
mock_execute.side_effect = execute_side_effect
result = client.list_tools()
assert result == expected_tools
mock_client_old.list_tools.assert_called_once()
mock_client_new.list_tools.assert_called_once()
auth_callback.assert_called_once()
mock_client_old.cleanup.assert_called_once()
mock_client_new.__enter__.assert_called_once()
assert result == expected_tools
assert mock_list_tools.call_count == 2
auth_callback.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when client not initialized."""
@ -418,28 +352,24 @@ class TestMCPClientWithAuthRetry:
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Client not initialized" in str(exc_info.value)
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = Mock()
client._client = mock_client
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
)
mock_client.invoke_tool.return_value = expected_result
result = client.invoke_tool("test-tool", {"arg": "value"})
# Mock the parent class invoke_tool method
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_client.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
assert result == expected_result
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
@patch("core.mcp.auth_client.MCPClient")
def test_invoke_tool_with_auth_retry(
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
):
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test invoke_tool with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
@ -448,14 +378,6 @@ class TestMCPClientWithAuthRetry:
mcp_service=mock_mcp_service,
)
# Configure mock clients (old and new)
mock_client_old = MagicMock()
mock_client_new = MagicMock()
client._client = mock_client_old
# Make _create_client return the new client on retry
mock_mcp_client_class.return_value = mock_client_new
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
@ -464,54 +386,26 @@ class TestMCPClientWithAuthRetry:
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
# First call raises auth error
mock_client_old.invoke_tool.side_effect = MCPAuthError("Auth failed")
mock_client_new.invoke_tool.return_value = expected_result
# We need to mock the behavior where after client is replaced,
# the new method should be called. Similar to list_tools test.
# Mock parent class invoke_tool to raise auth error first, then succeed
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
with patch.object(client, "_execute_with_retry") as mock_execute:
# Simulate the retry behavior
call_count = [0]
def execute_side_effect(func, *args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
# First call - simulate auth error and retry
try:
func(*args, **kwargs)
except MCPAuthError:
# Simulate the retry logic
client._handle_auth_error(MCPAuthError("Auth failed"))
if client._client:
client._client.cleanup()
client._client = mock_client_new
client._client.__enter__()
# Now return the result from the new client
return mock_client_new.invoke_tool(*args, **kwargs)
return func(*args, **kwargs)
mock_execute.side_effect = execute_side_effect
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_client_old.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
mock_client_new.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
auth_callback.assert_called_once()
mock_client_old.cleanup.assert_called_once()
mock_client_new.__enter__.assert_called_once()
assert result == expected_result
assert mock_invoke_tool.call_count == 2
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
auth_callback.assert_called_once()
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_client = Mock()
client._client = mock_client
client.cleanup()
mock_client.cleanup.assert_called_once()
assert client._client is None
# Mock the parent class cleanup method
with patch.object(MCPClient, "cleanup") as mock_cleanup:
client.cleanup()
mock_cleanup.assert_called_once()
def test_cleanup_no_client(self):
"""Test cleanup when no client exists."""
@ -520,4 +414,7 @@ class TestMCPClientWithAuthRetry:
# Should not raise
client.cleanup()
assert client._client is None
# Since MCPClientWithAuthRetry inherits from MCPClient,
# it doesn't have a _client attribute. The test should just
# verify that cleanup can be called without error.
assert not hasattr(client, "_client")