mirror of https://github.com/langgenius/dify.git
chore(mcp): fix pyright checks
This commit is contained in:
parent
5547247aa9
commit
685f199f91
|
|
@ -1,13 +1,14 @@
|
|||
"""Unit tests for MCP auth client with retry logic."""
|
||||
|
||||
from types import TracebackType
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
|
||||
|
||||
|
||||
|
|
@ -192,10 +193,7 @@ class TestMCPClientWithAuthRetry:
|
|||
mock_func.assert_called_once_with("arg1", kwarg1="value1")
|
||||
assert client._has_retried is False
|
||||
|
||||
@patch("core.mcp.auth_client.MCPClient")
|
||||
def test_execute_with_retry_auth_error_then_success(
|
||||
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
|
||||
):
|
||||
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test execution with auth error followed by successful retry."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
|
|
@ -204,14 +202,6 @@ class TestMCPClientWithAuthRetry:
|
|||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure mock clients (old and new)
|
||||
mock_client_old = MagicMock()
|
||||
mock_client_new = MagicMock()
|
||||
client._client = mock_client_old
|
||||
|
||||
# Make _create_client return the new client on retry
|
||||
mock_mcp_client_class.return_value = mock_client_new
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
|
|
@ -222,15 +212,22 @@ class TestMCPClientWithAuthRetry:
|
|||
# Mock function that fails first, then succeeds
|
||||
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
|
||||
|
||||
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
|
||||
# Mock the exit stack and session cleanup
|
||||
with (
|
||||
patch.object(client, "_exit_stack") as mock_exit_stack,
|
||||
patch.object(client, "_session") as mock_session,
|
||||
patch.object(client, "_initialize") as mock_initialize,
|
||||
):
|
||||
client._initialized = True
|
||||
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
mock_func.assert_called_with("arg1", kwarg1="value1")
|
||||
auth_callback.assert_called_once()
|
||||
mock_client_old.cleanup.assert_called_once()
|
||||
mock_client_new.__enter__.assert_called_once()
|
||||
assert client._has_retried is False # Reset after completion
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
mock_func.assert_called_with("arg1", kwarg1="value1")
|
||||
auth_callback.assert_called_once()
|
||||
mock_exit_stack.close.assert_called_once()
|
||||
mock_initialize.assert_called_once()
|
||||
assert client._has_retried is False # Reset after completion
|
||||
|
||||
def test_execute_with_retry_non_auth_error(self):
|
||||
"""Test execution with non-auth error (no retry)."""
|
||||
|
|
@ -244,29 +241,19 @@ class TestMCPClientWithAuthRetry:
|
|||
assert str(exc_info.value) == "Some other error"
|
||||
mock_func.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth_client.MCPClient")
|
||||
def test_context_manager_enter(self, mock_mcp_client_class):
|
||||
def test_context_manager_enter(self):
|
||||
"""Test context manager enter."""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_mcp_client_class.return_value = mock_client_instance
|
||||
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
result = client.__enter__()
|
||||
with patch.object(client, "_initialize") as mock_initialize:
|
||||
result = client.__enter__()
|
||||
|
||||
assert result == client
|
||||
assert client._client == mock_client_instance
|
||||
mock_client_instance.__enter__.assert_called_once()
|
||||
assert result == client
|
||||
assert client._initialized is True
|
||||
mock_initialize.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth_client.MCPClient")
|
||||
def test_context_manager_enter_with_auth_error(
|
||||
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
|
||||
):
|
||||
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test context manager enter with auth error and retry."""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_mcp_client_class.return_value = mock_client_instance
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
|
|
@ -281,37 +268,27 @@ class TestMCPClientWithAuthRetry:
|
|||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# First call to client.__enter__ raises auth error, second succeeds
|
||||
call_count = 0
|
||||
# Mock parent class __enter__ to raise auth error first, then succeed
|
||||
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
|
||||
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
|
||||
|
||||
def enter_side_effect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise MCPAuthError("Auth failed")
|
||||
return mock_client_instance
|
||||
result = client.__enter__()
|
||||
|
||||
mock_client_instance.__enter__.side_effect = enter_side_effect
|
||||
|
||||
result = client.__enter__()
|
||||
|
||||
assert result == client
|
||||
assert mock_client_instance.__enter__.call_count == 3
|
||||
auth_callback.assert_called_once()
|
||||
assert result == client
|
||||
assert mock_parent_enter.call_count == 2
|
||||
auth_callback.assert_called_once()
|
||||
|
||||
def test_context_manager_exit(self):
|
||||
"""Test context manager exit."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
mock_client = MagicMock()
|
||||
client._client = mock_client
|
||||
|
||||
exc_type: type[BaseException] | None = None
|
||||
exc_val: BaseException | None = None
|
||||
exc_tb: TracebackType | None = None
|
||||
client.__exit__(exc_type, exc_val, exc_tb)
|
||||
with patch.object(client, "cleanup") as mock_cleanup:
|
||||
exc_type: type[BaseException] | None = None
|
||||
exc_val: BaseException | None = None
|
||||
exc_tb: TracebackType | None = None
|
||||
client.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
mock_client.__exit__.assert_called_once_with(None, None, None)
|
||||
assert client._client is None
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
def test_list_tools_not_initialized(self):
|
||||
"""Test list_tools when client not initialized."""
|
||||
|
|
@ -320,13 +297,11 @@ class TestMCPClientWithAuthRetry:
|
|||
with pytest.raises(ValueError) as exc_info:
|
||||
client.list_tools()
|
||||
|
||||
assert "Client not initialized" in str(exc_info.value)
|
||||
assert "Session not initialized" in str(exc_info.value)
|
||||
|
||||
def test_list_tools_success(self):
|
||||
"""Test successful list_tools call."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
mock_client = Mock()
|
||||
client._client = mock_client
|
||||
|
||||
expected_tools = [
|
||||
Tool(
|
||||
|
|
@ -336,17 +311,13 @@ class TestMCPClientWithAuthRetry:
|
|||
annotations=ToolAnnotations(title="Test Tool"),
|
||||
)
|
||||
]
|
||||
mock_client.list_tools.return_value = expected_tools
|
||||
|
||||
result = client.list_tools()
|
||||
# Mock the parent class list_tools method
|
||||
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
|
||||
result = client.list_tools()
|
||||
assert result == expected_tools
|
||||
|
||||
assert result == expected_tools
|
||||
mock_client.list_tools.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth_client.MCPClient")
|
||||
def test_list_tools_with_auth_retry(
|
||||
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
|
||||
):
|
||||
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test list_tools with auth retry."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
|
|
@ -355,14 +326,6 @@ class TestMCPClientWithAuthRetry:
|
|||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure mock clients (old and new)
|
||||
mock_client_old = MagicMock()
|
||||
mock_client_new = MagicMock()
|
||||
client._client = mock_client_old
|
||||
|
||||
# Make _create_client return the new client on retry
|
||||
mock_mcp_client_class.return_value = mock_client_new
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
|
|
@ -371,45 +334,16 @@ class TestMCPClientWithAuthRetry:
|
|||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
|
||||
# First call raises auth error
|
||||
mock_client_old.list_tools.side_effect = MCPAuthError("Auth failed")
|
||||
mock_client_new.list_tools.return_value = expected_tools
|
||||
|
||||
# We need to mock the behavior where after client is replaced,
|
||||
# the new method should be called. But since the method reference
|
||||
# is already bound to the old client, we need to work around this.
|
||||
# Let's patch the _execute_with_retry to handle this properly.
|
||||
# Mock parent class list_tools to raise auth error first, then succeed
|
||||
with patch.object(MCPClient, "list_tools") as mock_list_tools:
|
||||
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
|
||||
|
||||
with patch.object(client, "_execute_with_retry") as mock_execute:
|
||||
# Simulate the retry behavior
|
||||
call_count = [0]
|
||||
|
||||
def execute_side_effect(func, *args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
# First call - simulate auth error and retry
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except MCPAuthError:
|
||||
# Simulate the retry logic
|
||||
client._handle_auth_error(MCPAuthError("Auth failed"))
|
||||
if client._client:
|
||||
client._client.cleanup()
|
||||
client._client = mock_client_new
|
||||
client._client.__enter__()
|
||||
# Now return the result from the new client
|
||||
return mock_client_new.list_tools(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
mock_execute.side_effect = execute_side_effect
|
||||
result = client.list_tools()
|
||||
|
||||
assert result == expected_tools
|
||||
mock_client_old.list_tools.assert_called_once()
|
||||
mock_client_new.list_tools.assert_called_once()
|
||||
auth_callback.assert_called_once()
|
||||
mock_client_old.cleanup.assert_called_once()
|
||||
mock_client_new.__enter__.assert_called_once()
|
||||
assert result == expected_tools
|
||||
assert mock_list_tools.call_count == 2
|
||||
auth_callback.assert_called_once()
|
||||
|
||||
def test_invoke_tool_not_initialized(self):
|
||||
"""Test invoke_tool when client not initialized."""
|
||||
|
|
@ -418,28 +352,24 @@ class TestMCPClientWithAuthRetry:
|
|||
with pytest.raises(ValueError) as exc_info:
|
||||
client.invoke_tool("test-tool", {"arg": "value"})
|
||||
|
||||
assert "Client not initialized" in str(exc_info.value)
|
||||
assert "Session not initialized" in str(exc_info.value)
|
||||
|
||||
def test_invoke_tool_success(self):
|
||||
"""Test successful invoke_tool call."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
mock_client = Mock()
|
||||
client._client = mock_client
|
||||
|
||||
expected_result = CallToolResult(
|
||||
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
|
||||
)
|
||||
mock_client.invoke_tool.return_value = expected_result
|
||||
|
||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
||||
# Mock the parent class invoke_tool method
|
||||
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
|
||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
|
||||
assert result == expected_result
|
||||
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
|
||||
|
||||
@patch("core.mcp.auth_client.MCPClient")
|
||||
def test_invoke_tool_with_auth_retry(
|
||||
self, mock_mcp_client_class, mock_provider_entity, mock_mcp_service, auth_callback
|
||||
):
|
||||
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test invoke_tool with auth retry."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
|
|
@ -448,14 +378,6 @@ class TestMCPClientWithAuthRetry:
|
|||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure mock clients (old and new)
|
||||
mock_client_old = MagicMock()
|
||||
mock_client_new = MagicMock()
|
||||
client._client = mock_client_old
|
||||
|
||||
# Make _create_client return the new client on retry
|
||||
mock_mcp_client_class.return_value = mock_client_new
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
|
|
@ -464,54 +386,26 @@ class TestMCPClientWithAuthRetry:
|
|||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
|
||||
# First call raises auth error
|
||||
mock_client_old.invoke_tool.side_effect = MCPAuthError("Auth failed")
|
||||
mock_client_new.invoke_tool.return_value = expected_result
|
||||
|
||||
# We need to mock the behavior where after client is replaced,
|
||||
# the new method should be called. Similar to list_tools test.
|
||||
# Mock parent class invoke_tool to raise auth error first, then succeed
|
||||
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
|
||||
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
|
||||
|
||||
with patch.object(client, "_execute_with_retry") as mock_execute:
|
||||
# Simulate the retry behavior
|
||||
call_count = [0]
|
||||
|
||||
def execute_side_effect(func, *args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
# First call - simulate auth error and retry
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except MCPAuthError:
|
||||
# Simulate the retry logic
|
||||
client._handle_auth_error(MCPAuthError("Auth failed"))
|
||||
if client._client:
|
||||
client._client.cleanup()
|
||||
client._client = mock_client_new
|
||||
client._client.__enter__()
|
||||
# Now return the result from the new client
|
||||
return mock_client_new.invoke_tool(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
mock_execute.side_effect = execute_side_effect
|
||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
||||
|
||||
assert result == expected_result
|
||||
mock_client_old.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
|
||||
mock_client_new.invoke_tool.assert_called_once_with("test-tool", {"arg": "value"})
|
||||
auth_callback.assert_called_once()
|
||||
mock_client_old.cleanup.assert_called_once()
|
||||
mock_client_new.__enter__.assert_called_once()
|
||||
assert result == expected_result
|
||||
assert mock_invoke_tool.call_count == 2
|
||||
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
|
||||
auth_callback.assert_called_once()
|
||||
|
||||
def test_cleanup(self):
|
||||
"""Test cleanup method."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
mock_client = Mock()
|
||||
client._client = mock_client
|
||||
|
||||
client.cleanup()
|
||||
|
||||
mock_client.cleanup.assert_called_once()
|
||||
assert client._client is None
|
||||
# Mock the parent class cleanup method
|
||||
with patch.object(MCPClient, "cleanup") as mock_cleanup:
|
||||
client.cleanup()
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
def test_cleanup_no_client(self):
|
||||
"""Test cleanup when no client exists."""
|
||||
|
|
@ -520,4 +414,7 @@ class TestMCPClientWithAuthRetry:
|
|||
# Should not raise
|
||||
client.cleanup()
|
||||
|
||||
assert client._client is None
|
||||
# Since MCPClientWithAuthRetry inherits from MCPClient,
|
||||
# it doesn't have a _client attribute. The test should just
|
||||
# verify that cleanup can be called without error.
|
||||
assert not hasattr(client, "_client")
|
||||
|
|
|
|||
Loading…
Reference in New Issue