From 685f199f91c6c03e07a5ea5b5c89aacfe514809a Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 16 Sep 2025 17:25:09 +0800 Subject: [PATCH] chore(mcp): fix pyright checks --- .../unit_tests/core/mcp/test_auth_client.py | 247 +++++------------- 1 file changed, 72 insertions(+), 175 deletions(-) diff --git a/api/tests/unit_tests/core/mcp/test_auth_client.py b/api/tests/unit_tests/core/mcp/test_auth_client.py index 58fa85f4f9..7b06c9df4d 100644 --- a/api/tests/unit_tests/core/mcp/test_auth_client.py +++ b/api/tests/unit_tests/core/mcp/test_auth_client.py @@ -1,13 +1,14 @@ """Unit tests for MCP auth client with retry logic.""" from types import TracebackType -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from core.entities.mcp_provider import MCPProviderEntity from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError +from core.mcp.mcp_client import MCPClient from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations @@ -192,10 +193,7 @@ class TestMCPClientWithAuthRetry: mock_func.assert_called_once_with("arg1", kwarg1="value1") assert client._has_retried is False - @patch("core.mcp.auth_client.MCPClient") - def test_execute_with_retry_auth_error_then_success( - self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback - ): + def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback): """Test execution with auth error followed by successful retry.""" client = MCPClientWithAuthRetry( server_url="http://test.example.com", @@ -204,14 +202,6 @@ class TestMCPClientWithAuthRetry: mcp_service=mock_mcp_service, ) - # Configure mock clients (old and new) - mock_client_old = MagicMock() - mock_client_new = MagicMock() - client._client = mock_client_old - - # Make _create_client return the new client on retry - mock_mcp_client_class.return_value = mock_client_new - # Configure new provider with token new_provider = Mock(spec=MCPProviderEntity) new_provider.retrieve_tokens.return_value = Mock( @@ -222,15 +212,22 @@ class TestMCPClientWithAuthRetry: # Mock function that fails first, then succeeds mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"]) - result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") + # Mock the exit stack and session cleanup + with ( + patch.object(client, "_exit_stack") as mock_exit_stack, + patch.object(client, "_session") as mock_session, + patch.object(client, "_initialize") as mock_initialize, + ): + client._initialized = True + result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1") - assert result == "success" - assert mock_func.call_count == 2 - mock_func.assert_called_with("arg1", kwarg1="value1") - auth_callback.assert_called_once() - mock_client_old.cleanup.assert_called_once() - mock_client_new.__enter__.assert_called_once() - assert client._has_retried is False # Reset after completion + assert result == "success" + assert mock_func.call_count == 2 + mock_func.assert_called_with("arg1", kwarg1="value1") + auth_callback.assert_called_once() + mock_exit_stack.close.assert_called_once() + mock_initialize.assert_called_once() + assert client._has_retried is False # Reset after completion def test_execute_with_retry_non_auth_error(self): """Test execution with non-auth error (no retry).""" @@ -244,29 +241,19 @@ class TestMCPClientWithAuthRetry: assert str(exc_info.value) == "Some other error" mock_func.assert_called_once() - @patch("core.mcp.auth_client.MCPClient") - def test_context_manager_enter(self, mock_mcp_client_class): + def test_context_manager_enter(self): """Test context manager enter.""" - mock_client_instance = MagicMock() - mock_client_instance.__enter__.return_value = mock_client_instance - mock_mcp_client_class.return_value = mock_client_instance - client = MCPClientWithAuthRetry(server_url="http://test.example.com") - result = client.__enter__() + with patch.object(client, "_initialize") as mock_initialize: + result = client.__enter__() - assert result == client - assert client._client == mock_client_instance - mock_client_instance.__enter__.assert_called_once() + assert result == client + assert client._initialized is True + mock_initialize.assert_called_once() - @patch("core.mcp.auth_client.MCPClient") - def test_context_manager_enter_with_auth_error( - self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback - ): + def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback): """Test context manager enter with auth error and retry.""" - mock_client_instance = MagicMock() - mock_mcp_client_class.return_value = mock_client_instance - # Configure new provider with token new_provider = Mock(spec=MCPProviderEntity) new_provider.retrieve_tokens.return_value = Mock( @@ -281,37 +268,27 @@ class TestMCPClientWithAuthRetry: mcp_service=mock_mcp_service, ) - # First call to client.__enter__ raises auth error, second succeeds - call_count = 0 + # Mock parent class __enter__ to raise auth error first, then succeed + with patch.object(MCPClient, "__enter__") as mock_parent_enter: + mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client] - def enter_side_effect(): - nonlocal call_count - call_count += 1 - if call_count == 1: - raise MCPAuthError("Auth failed") - return mock_client_instance + result = client.__enter__() - mock_client_instance.__enter__.side_effect = enter_side_effect - - result = client.__enter__() - - assert result == client - assert mock_client_instance.__enter__.call_count == 3 - auth_callback.assert_called_once() + assert result == client + assert mock_parent_enter.call_count == 2 + auth_callback.assert_called_once() def test_context_manager_exit(self): """Test context manager exit.""" client = MCPClientWithAuthRetry(server_url="http://test.example.com") - mock_client = MagicMock() - client._client = mock_client - exc_type: type[BaseException] | None = None - exc_val: BaseException | None = None - exc_tb: TracebackType | None = None - client.__exit__(exc_type, exc_val, exc_tb) + with patch.object(client, "cleanup") as mock_cleanup: + exc_type: type[BaseException] | None = None + exc_val: BaseException | None = None + exc_tb: TracebackType | None = None + client.__exit__(exc_type, exc_val, exc_tb) - mock_client.__exit__.assert_called_once_with(None, None, None) - assert client._client is None + mock_cleanup.assert_called_once() def test_list_tools_not_initialized(self): """Test list_tools when client not initialized.""" @@ -320,13 +297,11 @@ class TestMCPClientWithAuthRetry: with pytest.raises(ValueError) as exc_info: client.list_tools() - assert "Client not initialized" in str(exc_info.value) + assert "Session not initialized" in str(exc_info.value) def test_list_tools_success(self): """Test successful list_tools call.""" client = MCPClientWithAuthRetry(server_url="http://test.example.com") - mock_client = Mock() - client._client = mock_client expected_tools = [ Tool( @@ -336,17 +311,13 @@ class TestMCPClientWithAuthRetry: annotations=ToolAnnotations(title="Test Tool"), ) ] - mock_client.list_tools.return_value = expected_tools - result = client.list_tools() + # Mock the parent class list_tools method + with patch.object(MCPClient, "list_tools", return_value=expected_tools): + result = client.list_tools() + assert result == expected_tools - assert result == expected_tools - mock_client.list_tools.assert_called_once() - - @patch("core.mcp.auth_client.MCPClient") - def test_list_tools_with_auth_retry( - self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback - ): + def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback): """Test list_tools with auth retry.""" client = MCPClientWithAuthRetry( server_url="http://test.example.com", @@ -355,14 +326,6 @@ class TestMCPClientWithAuthRetry: mcp_service=mock_mcp_service, ) - # Configure mock clients (old and new) - mock_client_old = MagicMock() - mock_client_new = MagicMock() - client._client = mock_client_old - - # Make _create_client return the new client on retry - mock_mcp_client_class.return_value = mock_client_new - # Configure new provider with token new_provider = Mock(spec=MCPProviderEntity) new_provider.retrieve_tokens.return_value = Mock( @@ -371,45 +334,16 @@ class TestMCPClientWithAuthRetry: mock_mcp_service.get_provider_entity.return_value = new_provider expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})] - # First call raises auth error - mock_client_old.list_tools.side_effect = MCPAuthError("Auth failed") - mock_client_new.list_tools.return_value = expected_tools - # We need to mock the behavior where after client is replaced, - # the new method should be called. But since the method reference - # is already bound to the old client, we need to work around this. - # Let's patch the _execute_with_retry to handle this properly. + # Mock parent class list_tools to raise auth error first, then succeed + with patch.object(MCPClient, "list_tools") as mock_list_tools: + mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools] - with patch.object(client, "_execute_with_retry") as mock_execute: - # Simulate the retry behavior - call_count = [0] - - def execute_side_effect(func, *args, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: - # First call - simulate auth error and retry - try: - func(*args, **kwargs) - except MCPAuthError: - # Simulate the retry logic - client._handle_auth_error(MCPAuthError("Auth failed")) - if client._client: - client._client.cleanup() - client._client = mock_client_new - client._client.__enter__() - # Now return the result from the new client - return mock_client_new.list_tools(*args, **kwargs) - return func(*args, **kwargs) - - mock_execute.side_effect = execute_side_effect result = client.list_tools() - assert result == expected_tools - mock_client_old.list_tools.assert_called_once() - mock_client_new.list_tools.assert_called_once() - auth_callback.assert_called_once() - mock_client_old.cleanup.assert_called_once() - mock_client_new.__enter__.assert_called_once() + assert result == expected_tools + assert mock_list_tools.call_count == 2 + auth_callback.assert_called_once() def test_invoke_tool_not_initialized(self): """Test invoke_tool when client not initialized.""" @@ -418,28 +352,24 @@ class TestMCPClientWithAuthRetry: with pytest.raises(ValueError) as exc_info: client.invoke_tool("test-tool", {"arg": "value"}) - assert "Client not initialized" in str(exc_info.value) + assert "Session not initialized" in str(exc_info.value) def test_invoke_tool_success(self): """Test successful invoke_tool call.""" client = MCPClientWithAuthRetry(server_url="http://test.example.com") - mock_client = Mock() - client._client = mock_client expected_result = CallToolResult( content=[TextContent(type="text", text="Tool executed successfully")], isError=False ) - mock_client.invoke_tool.return_value = expected_result - result = client.invoke_tool("test-tool", {"arg": "value"}) + # Mock the parent class invoke_tool method + with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke: + result = client.invoke_tool("test-tool", {"arg": "value"}) - assert result == expected_result - mock_client.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"}) + assert result == expected_result + mock_invoke.assert_called_once_with("test-tool", {"arg": "value"}) - @patch("core.mcp.auth_client.MCPClient") - def test_invoke_tool_with_auth_retry( - self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback - ): + def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback): """Test invoke_tool with auth retry.""" client = MCPClientWithAuthRetry( server_url="http://test.example.com", @@ -448,14 +378,6 @@ class TestMCPClientWithAuthRetry: mcp_service=mock_mcp_service, ) - # Configure mock clients (old and new) - mock_client_old = MagicMock() - mock_client_new = MagicMock() - client._client = mock_client_old - - # Make _create_client return the new client on retry - mock_mcp_client_class.return_value = mock_client_new - # Configure new provider with token new_provider = Mock(spec=MCPProviderEntity) new_provider.retrieve_tokens.return_value = Mock( @@ -464,54 +386,26 @@ class TestMCPClientWithAuthRetry: mock_mcp_service.get_provider_entity.return_value = new_provider expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False) - # First call raises auth error - mock_client_old.invoke_tool.side_effect = MCPAuthError("Auth failed") - mock_client_new.invoke_tool.return_value = expected_result - # We need to mock the behavior where after client is replaced, - # the new method should be called. Similar to list_tools test. + # Mock parent class invoke_tool to raise auth error first, then succeed + with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool: + mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result] - with patch.object(client, "_execute_with_retry") as mock_execute: - # Simulate the retry behavior - call_count = [0] - - def execute_side_effect(func, *args, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: - # First call - simulate auth error and retry - try: - func(*args, **kwargs) - except MCPAuthError: - # Simulate the retry logic - client._handle_auth_error(MCPAuthError("Auth failed")) - if client._client: - client._client.cleanup() - client._client = mock_client_new - client._client.__enter__() - # Now return the result from the new client - return mock_client_new.invoke_tool(*args, **kwargs) - return func(*args, **kwargs) - - mock_execute.side_effect = execute_side_effect result = client.invoke_tool("test-tool", {"arg": "value"}) - assert result == expected_result - mock_client_old.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"}) - mock_client_new.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"}) - auth_callback.assert_called_once() - mock_client_old.cleanup.assert_called_once() - mock_client_new.__enter__.assert_called_once() + assert result == expected_result + assert mock_invoke_tool.call_count == 2 + mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"}) + auth_callback.assert_called_once() def test_cleanup(self): """Test cleanup method.""" client = MCPClientWithAuthRetry(server_url="http://test.example.com") - mock_client = Mock() - client._client = mock_client - client.cleanup() - - mock_client.cleanup.assert_called_once() - assert client._client is None + # Mock the parent class cleanup method + with patch.object(MCPClient, "cleanup") as mock_cleanup: + client.cleanup() + mock_cleanup.assert_called_once() def test_cleanup_no_client(self): """Test cleanup when no client exists.""" @@ -520,4 +414,7 @@ class TestMCPClientWithAuthRetry: # Should not raise client.cleanup() - assert client._client is None + # Since MCPClientWithAuthRetry inherits from MCPClient, + # it doesn't have a _client attribute. The test should just + # verify that cleanup can be called without error. + assert not hasattr(client, "_client")