mirror of https://github.com/langgenius/dify.git
refactor(mcp): clean the auth code
This commit is contained in:
parent
8cf4a0d3ad
commit
ffd3a461f6
|
|
@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
|
|||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
|
@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
|
|||
args = parser.parse_args()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Validate server URL
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Parse and validate models
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
|
@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
|
|||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||
validation_result = None
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_result = service.validate_server_url_change(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||
)
|
||||
|
||||
# No need to check for errors here, exceptions will be raised directly
|
||||
|
||||
# Step 2: Perform database update in a transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
|
|
@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
|
|||
headers=args["headers"],
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
|
|||
provider_id = args["provider_id"]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_authentication()
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_authentication()
|
||||
|
||||
# Try to connect without active transaction
|
||||
# Try to connect without active transaction
|
||||
try:
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Create new transaction for update
|
||||
with session.begin():
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
try:
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Create new transaction for update
|
||||
with session.begin():
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
try:
|
||||
return auth(provider_entity, service, args.get("authorization_code"))
|
||||
except MCPRefreshTokenError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||
with session.begin():
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
|
|
@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
|
@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
|
|||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_providers(tenant_id=tenant_id)
|
||||
|
||||
|
|
@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
|
|||
# Create service instance for handle_callback
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
handle_callback(state_key, authorization_code, mcp_service)
|
||||
# handle_callback now returns state data and tokens
|
||||
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||
# Save tokens using the service layer
|
||||
mcp_service.save_oauth_data(
|
||||
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
|
||||
)
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@ import json
|
|||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
|
|
@ -23,23 +23,10 @@ from core.mcp.types import (
|
|||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
|
|
@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
|||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
|
||||
"""
|
||||
Handle the callback from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
A tuple of (callback_state, tokens) that can be used by the caller to save data.
|
||||
"""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
|
|
@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
|
|||
full_state_data.redirect_uri,
|
||||
)
|
||||
|
||||
# Save tokens using the service layer
|
||||
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
|
||||
|
||||
return full_state_data
|
||||
return full_state_data, tokens
|
||||
|
||||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
|
|
@ -361,11 +350,24 @@ def register_client(
|
|||
|
||||
def auth(
|
||||
provider: MCPProviderEntity,
|
||||
mcp_service: "MCPToolManageService",
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
||||
This function performs only network operations and returns actions that need
|
||||
to be performed by the caller (such as saving data to database).
|
||||
|
||||
Args:
|
||||
provider: The MCP provider entity
|
||||
authorization_code: Optional authorization code from OAuth callback
|
||||
state_param: Optional state parameter from OAuth callback
|
||||
|
||||
Returns:
|
||||
AuthResult containing actions to be performed and response data
|
||||
"""
|
||||
actions: list[AuthAction] = []
|
||||
server_url = provider.decrypt_server_url()
|
||||
server_metadata = discover_oauth_metadata(server_url)
|
||||
client_metadata = provider.client_metadata
|
||||
|
|
@ -407,9 +409,14 @@ def auth(
|
|||
except RequestError as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
|
||||
# Save client information using service layer
|
||||
mcp_service.save_oauth_data(
|
||||
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
|
||||
# Return action to save client information
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CLIENT_INFO,
|
||||
data={"client_information": full_information.model_dump()},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
client_information = full_information
|
||||
|
|
@ -426,12 +433,20 @@ def auth(
|
|||
scope,
|
||||
)
|
||||
|
||||
# Save tokens and grant type
|
||||
# Return action to save tokens and grant type
|
||||
token_data = tokens.model_dump()
|
||||
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
|
||||
|
||||
return {"result": "success"}
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=token_data,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
|
|
@ -465,10 +480,17 @@ def auth(
|
|||
redirect_uri,
|
||||
)
|
||||
|
||||
# Save tokens using service layer
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
|
||||
# Return action to save tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
|
||||
provider_tokens = provider.retrieve_tokens()
|
||||
|
||||
|
|
@ -479,10 +501,17 @@ def auth(
|
|||
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Save new tokens using service layer
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
|
||||
# Return action to save new tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=new_tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
|
|
@ -499,7 +528,14 @@ def auth(
|
|||
tenant_id,
|
||||
)
|
||||
|
||||
# Save code verifier using service layer
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
|
||||
# Return action to save code verifier
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CODE_VERIFIER,
|
||||
data={"code_verifier": code_verifier},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return {"authorization_url": authorization_url}
|
||||
return AuthResult(actions=actions, response={"authorization_url": authorization_url})
|
||||
|
|
|
|||
|
|
@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens.
|
|||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||
|
||||
This class extends MCPClient and intercepts MCPAuthError exceptions
|
||||
to refresh authentication before retrying failed operations.
|
||||
|
||||
Note: This class uses lazy session creation - database sessions are only
|
||||
created when authentication retry is actually needed, not on every request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
|
||||
| None = None,
|
||||
authorization_code: str | None = None,
|
||||
by_server_id: bool = False,
|
||||
mcp_service: Optional["MCPToolManageService"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
|
|
@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||
timeout: Request timeout
|
||||
sse_read_timeout: SSE read timeout
|
||||
provider_entity: Provider entity for authentication
|
||||
auth_callback: Authentication callback function
|
||||
authorization_code: Optional authorization code for initial auth
|
||||
by_server_id: Whether to look up provider by server ID
|
||||
mcp_service: MCP service instance
|
||||
"""
|
||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||
|
||||
self.provider_entity = provider_entity
|
||||
self.auth_callback = auth_callback
|
||||
self.authorization_code = authorization_code
|
||||
self.by_server_id = by_server_id
|
||||
self.mcp_service = mcp_service
|
||||
self._has_retried = False
|
||||
|
||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||
"""
|
||||
Handle authentication error by refreshing tokens.
|
||||
|
||||
This method creates a short-lived database session only when authentication
|
||||
retry is needed, minimizing database connection hold time.
|
||||
|
||||
Args:
|
||||
error: The authentication error
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
|
||||
if not self.provider_entity:
|
||||
raise error
|
||||
if self._has_retried:
|
||||
raise error
|
||||
|
|
@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||
self._has_retried = True
|
||||
|
||||
try:
|
||||
# Perform authentication
|
||||
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
|
||||
# Create a temporary session only for auth retry
|
||||
# This session is short-lived and only exists during the auth operation
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = self.mcp_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
# Perform authentication using the service's auth method
|
||||
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = mcp_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
|
||||
# Session is closed here, before we update headers
|
||||
token = self.provider_entity.retrieve_tokens()
|
||||
if not token:
|
||||
raise MCPAuthError("Authentication failed - no token received")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
|
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
|
|||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
|
||||
|
||||
class AuthActionType(StrEnum):
|
||||
"""Types of actions that can be performed during auth flow."""
|
||||
|
||||
SAVE_CLIENT_INFO = "save_client_info"
|
||||
SAVE_TOKENS = "save_tokens"
|
||||
SAVE_CODE_VERIFIER = "save_code_verifier"
|
||||
START_AUTHORIZATION = "start_authorization"
|
||||
SUCCESS = "success"
|
||||
|
||||
|
||||
class AuthAction(BaseModel):
|
||||
"""Represents an action that needs to be performed as a result of auth flow."""
|
||||
|
||||
action_type: AuthActionType
|
||||
data: dict[str, Any]
|
||||
provider_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class AuthResult(BaseModel):
|
||||
"""Result of auth function containing actions to be performed and response data."""
|
||||
|
||||
actions: list[AuthAction]
|
||||
response: dict[str, str]
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
"""State data stored in Redis during OAuth callback flow."""
|
||||
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import json
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
|
|
@ -125,71 +124,39 @@ class MCPTool(Tool):
|
|||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
# Get provider entity to access tokens
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Get MCP service from invoke parameters or create new one
|
||||
provider_entity = None
|
||||
mcp_service = None
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
# Check if mcp_service is passed in tool_parameters
|
||||
if "_mcp_service" in tool_parameters:
|
||||
mcp_service = tool_parameters.pop("_mcp_service")
|
||||
if mcp_service:
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
headers = provider_entity.decrypt_headers()
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
# Step 1: Load provider entity and credentials in a short-lived session
|
||||
# This minimizes database connection hold time
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except (ValueError, TypeError, KeyError) as e:
|
||||
# Catch specific exceptions that might occur during tool invocation
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
else:
|
||||
# Fallback to creating service with database session
|
||||
from sqlalchemy.orm import Session
|
||||
# Decrypt and prepare all credentials before closing session
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
headers = provider_entity.decrypt_headers()
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -12,6 +14,7 @@ from sqlalchemy.orm import Session
|
|||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
|
|
@ -28,6 +31,38 @@ EMPTY_TOOLS_JSON = "[]"
|
|||
EMPTY_CREDENTIALS_JSON = "{}"
|
||||
|
||||
|
||||
class OAuthDataType(StrEnum):
|
||||
"""Types of OAuth data that can be saved."""
|
||||
|
||||
TOKENS = "tokens"
|
||||
CLIENT_INFO = "client_info"
|
||||
CODE_VERIFIER = "code_verifier"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class ReconnectResult(BaseModel):
|
||||
"""Result of reconnecting to an MCP provider"""
|
||||
|
||||
authed: bool = Field(description="Whether the provider is authenticated")
|
||||
tools: str = Field(description="JSON string of tool list")
|
||||
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
|
||||
|
||||
|
||||
class ServerUrlValidationResult(BaseModel):
|
||||
"""Result of server URL validation check"""
|
||||
|
||||
needs_validation: bool
|
||||
validation_passed: bool = False
|
||||
reconnect_result: ReconnectResult | None = None
|
||||
encrypted_server_url: str | None = None
|
||||
server_url_hash: str | None = None
|
||||
|
||||
@property
|
||||
def should_update_server_url(self) -> bool:
|
||||
"""Check if server URL should be updated based on validation result"""
|
||||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
|
|
@ -91,6 +126,10 @@ class MCPToolManageService:
|
|||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(server_url):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
# Check for existing provider
|
||||
|
|
@ -99,13 +138,12 @@ class MCPToolManageService:
|
|||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
if authentication is not None and authentication.client_id and authentication.client_secret:
|
||||
# Build the full credentials structure with encrypted client_id and client_secret
|
||||
encrypted_credentials = None
|
||||
if authentication is not None and authentication.client_id:
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
authentication.client_id, authentication.client_secret, tenant_id
|
||||
)
|
||||
else:
|
||||
encrypted_credentials = None
|
||||
|
||||
# Create provider
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -142,24 +180,39 @@ class MCPToolManageService:
|
|||
headers: dict[str, str] | None = None,
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
validation_result: ServerUrlValidationResult | None = None,
|
||||
) -> None:
|
||||
"""Update an MCP provider."""
|
||||
"""
|
||||
Update an MCP provider.
|
||||
|
||||
Args:
|
||||
validation_result: Pre-validation result from validate_server_url_change.
|
||||
If provided and contains reconnect_result, it will be used
|
||||
instead of performing network operations.
|
||||
"""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
reconnect_result = None
|
||||
# Check for duplicate name (excluding current provider)
|
||||
if name != mcp_provider.name:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.id != provider_id,
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
if existing_provider:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
|
||||
# Get URL update data from validation result
|
||||
encrypted_server_url = None
|
||||
server_url_hash = None
|
||||
reconnect_result = None
|
||||
|
||||
# Handle server URL update
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
if server_url_hash != mcp_provider.server_url_hash:
|
||||
reconnect_result = self._reconnect_provider(
|
||||
server_url=server_url,
|
||||
provider=mcp_provider,
|
||||
)
|
||||
if validation_result and validation_result.encrypted_server_url:
|
||||
# Use all data from validation result
|
||||
encrypted_server_url = validation_result.encrypted_server_url
|
||||
server_url_hash = validation_result.server_url_hash
|
||||
reconnect_result = validation_result.reconnect_result
|
||||
|
||||
try:
|
||||
# Update basic fields
|
||||
|
|
@ -169,63 +222,35 @@ class MCPToolManageService:
|
|||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
# Update server URL if changed
|
||||
if encrypted_server_url is not None and server_url_hash is not None:
|
||||
if encrypted_server_url and server_url_hash:
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
mcp_provider.server_url_hash = server_url_hash
|
||||
|
||||
if reconnect_result:
|
||||
mcp_provider.authed = reconnect_result["authed"]
|
||||
mcp_provider.tools = reconnect_result["tools"]
|
||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||
mcp_provider.authed = reconnect_result.authed
|
||||
mcp_provider.tools = reconnect_result.tools
|
||||
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
|
||||
|
||||
# Update optional fields
|
||||
if configuration.timeout is not None:
|
||||
mcp_provider.timeout = configuration.timeout
|
||||
if configuration.sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = configuration.sse_read_timeout
|
||||
# Update optional configuration fields
|
||||
self._update_optional_fields(mcp_provider, configuration)
|
||||
|
||||
# Update headers if provided
|
||||
if headers is not None:
|
||||
if headers:
|
||||
# Build headers preserving unchanged masked values
|
||||
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = encrypted_headers_dict
|
||||
else:
|
||||
# Clear headers if empty dict passed
|
||||
mcp_provider.encrypted_headers = None
|
||||
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
|
||||
|
||||
# Update credentials if provided
|
||||
if authentication is not None and authentication.client_id and authentication.client_secret:
|
||||
# Merge with existing credentials to handle masked values
|
||||
(
|
||||
final_client_id,
|
||||
final_client_secret,
|
||||
) = self._merge_credentials_with_masked(
|
||||
authentication.client_id, authentication.client_secret, mcp_provider
|
||||
)
|
||||
if authentication and authentication.client_id:
|
||||
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||
|
||||
# Build and encrypt new credentials
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
final_client_id, final_client_secret, tenant_id
|
||||
)
|
||||
mcp_provider.encrypted_credentials = encrypted_credentials
|
||||
|
||||
self._session.commit()
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
except IntegrityError as e:
|
||||
self._session.rollback()
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
# Catch specific exceptions that might occur during update
|
||||
# ValueError: invalid data provided
|
||||
# AttributeError: missing required attributes
|
||||
# TypeError: type conversion errors
|
||||
self._session.rollback()
|
||||
raise
|
||||
|
||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||
"""Delete an MCP provider."""
|
||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
self._session.commit()
|
||||
|
||||
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
"""List all MCP providers for a tenant."""
|
||||
|
|
@ -241,8 +266,6 @@ class MCPToolManageService:
|
|||
|
||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
"""List tools from remote MCP server."""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
# Load provider and convert to entity
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
|
@ -257,9 +280,7 @@ class MCPToolManageService:
|
|||
# Retrieve tools from remote server
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(
|
||||
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
|
||||
)
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
|
|
@ -305,9 +326,12 @@ class MCPToolManageService:
|
|||
if not authed:
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
|
||||
self._session.commit()
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
|
||||
def save_oauth_data(
|
||||
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
|
||||
) -> None:
|
||||
"""
|
||||
Save OAuth-related data (tokens, client info, code verifier).
|
||||
|
||||
|
|
@ -315,12 +339,14 @@ class MCPToolManageService:
|
|||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
data: Data to save (tokens, client info, or code verifier)
|
||||
data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed')
|
||||
data_type: Type of OAuth data to save
|
||||
"""
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Determine if this makes the provider authenticated
|
||||
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
|
||||
authed = (
|
||||
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||
)
|
||||
|
||||
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
|
||||
|
||||
|
|
@ -330,7 +356,6 @@ class MCPToolManageService:
|
|||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||
provider.updated_at = datetime.now()
|
||||
provider.authed = False
|
||||
self._session.commit()
|
||||
|
||||
# ========== Private Helper Methods ==========
|
||||
|
||||
|
|
@ -406,41 +431,123 @@ class MCPToolManageService:
|
|||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
provider_entity: MCPProviderEntity,
|
||||
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", str | None], dict[str, str]],
|
||||
):
|
||||
"""Retrieve tools from remote MCP server."""
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url,
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=self,
|
||||
) as mcp_client:
|
||||
return mcp_client.list_tools()
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
|
||||
"""
|
||||
Execute the actions returned by the auth function.
|
||||
|
||||
This method processes the AuthResult and performs the necessary database operations.
|
||||
|
||||
Args:
|
||||
auth_result: The result from the auth function
|
||||
|
||||
Returns:
|
||||
The response from the auth result
|
||||
"""
|
||||
from core.mcp.entities import AuthAction, AuthActionType
|
||||
|
||||
action: AuthAction
|
||||
for action in auth_result.actions:
|
||||
if action.provider_id is None or action.tenant_id is None:
|
||||
continue
|
||||
|
||||
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
|
||||
elif action.action_type == AuthActionType.SAVE_TOKENS:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
|
||||
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
|
||||
|
||||
return auth_result.response
|
||||
|
||||
def auth_with_actions(
|
||||
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Perform authentication and execute all resulting actions.
|
||||
|
||||
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
|
||||
|
||||
Args:
|
||||
provider_entity: The MCP provider entity
|
||||
authorization_code: Optional authorization code
|
||||
|
||||
Returns:
|
||||
Response dictionary from auth result
|
||||
"""
|
||||
auth_result = auth(provider_entity, authorization_code)
|
||||
return self.execute_auth_actions(auth_result)
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(
|
||||
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
return {
|
||||
"authed": True,
|
||||
"tools": json.dumps([tool.model_dump() for tool in tools]),
|
||||
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
|
||||
}
|
||||
except MCPAuthError:
|
||||
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def validate_server_url_change(
|
||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||
) -> ServerUrlValidationResult:
|
||||
"""
|
||||
Validate server URL change by attempting to connect to the new server.
|
||||
This method should be called BEFORE update_provider to perform network operations
|
||||
outside of the database transaction.
|
||||
|
||||
Returns:
|
||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||
"""
|
||||
# Handle hidden/unchanged URL
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
|
||||
return ServerUrlValidationResult(needs_validation=False)
|
||||
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(new_server_url):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Always encrypt and hash the URL
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||
|
||||
# Get current provider
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check if URL is actually different
|
||||
if new_server_url_hash == provider.server_url_hash:
|
||||
# URL hasn't changed, but still return the encrypted data
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||
)
|
||||
|
||||
# Perform validation by attempting to connect
|
||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=True,
|
||||
validation_passed=True,
|
||||
reconnect_result=reconnect_result,
|
||||
encrypted_server_url=encrypted_server_url,
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
) -> ToolProviderApiEntity:
|
||||
|
|
@ -466,6 +573,45 @@ class MCPToolManageService:
|
|||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
|
||||
def _is_valid_url(self, url: str) -> bool:
|
||||
"""Validate URL format."""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
|
||||
"""Update optional configuration fields using setattr for cleaner code."""
|
||||
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
|
||||
|
||||
for field, value in field_mapping.items():
|
||||
if value is not None:
|
||||
setattr(mcp_provider, field, value)
|
||||
|
||||
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
|
||||
"""Process headers update, handling empty dict to clear headers."""
|
||||
if not headers:
|
||||
return None
|
||||
|
||||
# Merge with existing headers to preserve masked values
|
||||
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||
return self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||
|
||||
def _process_credentials(
|
||||
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
|
||||
) -> str:
|
||||
"""Process credentials update, handling masked values."""
|
||||
# Merge with existing credentials
|
||||
final_client_id, final_client_secret = self._merge_credentials_with_masked(
|
||||
authentication.client_id, authentication.client_secret, mcp_provider
|
||||
)
|
||||
|
||||
# Build and encrypt
|
||||
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
|
||||
|
||||
def _merge_headers_with_masked(
|
||||
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
|
||||
) -> dict[str, str]:
|
||||
|
|
@ -530,12 +676,12 @@ class MCPToolManageService:
|
|||
# Create a flat structure with all credential data
|
||||
credentials_data = {
|
||||
"client_id": client_id,
|
||||
"encrypted_client_secret": client_secret,
|
||||
"client_name": CLIENT_NAME,
|
||||
"is_dynamic_registration": False,
|
||||
}
|
||||
|
||||
# Only client_id and client_secret need encryption
|
||||
secret_fields = ["encrypted_client_secret"] if client_secret else []
|
||||
secret_fields = []
|
||||
if client_secret is not None:
|
||||
credentials_data["encrypted_client_secret"] = encrypter.encrypt_token(tenant_id, client_secret)
|
||||
secret_fields = ["encrypted_client_secret"]
|
||||
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
return json.dumps({"client_information": client_info})
|
||||
|
|
|
|||
|
|
@ -1108,75 +1108,6 @@ class TestMCPToolManageService:
|
|||
assert icon_data["content"] == "🚀"
|
||||
assert icon_data["background"] == "#4ECDC4"
|
||||
|
||||
def test_update_mcp_provider_with_server_url_change(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful update of MCP provider with server URL change.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of server URL changes
|
||||
- Correct reconnection logic
|
||||
- Database state updates
|
||||
- External service integration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create MCP provider
|
||||
mcp_provider = self._create_test_mcp_provider(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Mock the reconnection method
|
||||
with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect:
|
||||
mock_reconnect.return_value = {
|
||||
"authed": True,
|
||||
"tools": '[{"name": "test_tool"}]',
|
||||
"encrypted_credentials": "{}",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service.update_provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=mcp_provider.id,
|
||||
name="Updated MCP Provider",
|
||||
server_url="https://new-example.com/mcp",
|
||||
icon="🚀",
|
||||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="updated_identifier_123",
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
db.session.refresh(mcp_provider)
|
||||
assert mcp_provider.name == "Updated MCP Provider"
|
||||
assert mcp_provider.server_identifier == "updated_identifier_123"
|
||||
assert mcp_provider.timeout == 45.0
|
||||
assert mcp_provider.sse_read_timeout == 400.0
|
||||
assert mcp_provider.updated_at is not None
|
||||
|
||||
# Verify reconnection was called
|
||||
mock_reconnect.assert_called_once_with(
|
||||
server_url="https://new-example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
)
|
||||
|
||||
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when updating MCP provider with duplicate name.
|
||||
|
|
@ -1387,14 +1318,14 @@ class TestMCPToolManageService:
|
|||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["authed"] is True
|
||||
assert result["tools"] is not None
|
||||
assert result["encrypted_credentials"] == "{}"
|
||||
assert result.authed is True
|
||||
assert result.tools is not None
|
||||
assert result.encrypted_credentials == "{}"
|
||||
|
||||
# Verify tools were properly serialized
|
||||
import json
|
||||
|
||||
tools_data = json.loads(result["tools"])
|
||||
tools_data = json.loads(result.tools)
|
||||
assert len(tools_data) == 2
|
||||
assert tools_data[0]["name"] == "test_tool_1"
|
||||
assert tools_data[1]["name"] == "test_tool_2"
|
||||
|
|
@ -1441,9 +1372,9 @@ class TestMCPToolManageService:
|
|||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["authed"] is False
|
||||
assert result["tools"] == "[]"
|
||||
assert result["encrypted_credentials"] == "{}"
|
||||
assert result.authed is False
|
||||
assert result.tools == "[]"
|
||||
assert result.encrypted_credentials == "{}"
|
||||
|
||||
def test_re_connect_mcp_provider_connection_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.mcp.auth.auth_flow import (
|
|||
register_client,
|
||||
start_authorization,
|
||||
)
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
|
|
@ -527,9 +528,10 @@ class TestCallbackHandling:
|
|||
# Setup service
|
||||
mock_service = Mock()
|
||||
|
||||
result = handle_callback("state-key", "auth-code", mock_service)
|
||||
state_result, tokens_result = handle_callback("state-key", "auth-code")
|
||||
|
||||
assert result == state_data
|
||||
assert state_result == state_data
|
||||
assert tokens_result == tokens
|
||||
|
||||
# Verify calls
|
||||
mock_retrieve_state.assert_called_once_with("state-key")
|
||||
|
|
@ -541,9 +543,8 @@ class TestCallbackHandling:
|
|||
"test-verifier",
|
||||
"https://redirect.example.com",
|
||||
)
|
||||
mock_service.save_oauth_data.assert_called_once_with(
|
||||
"test-provider", "test-tenant", tokens.model_dump(), "tokens"
|
||||
)
|
||||
# Note: handle_callback no longer saves tokens directly, it just returns them
|
||||
# The caller (e.g., controller) is responsible for saving via execute_auth_actions
|
||||
|
||||
|
||||
class TestAuthOrchestration:
|
||||
|
|
@ -589,21 +590,28 @@ class TestAuthOrchestration:
|
|||
)
|
||||
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
|
||||
|
||||
result = auth(mock_provider, mock_service)
|
||||
result = auth(mock_provider)
|
||||
|
||||
assert result == {"authorization_url": "https://auth.example.com/authorize?..."}
|
||||
# auth() now returns AuthResult
|
||||
assert isinstance(result, AuthResult)
|
||||
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
|
||||
|
||||
# Verify that the result contains the correct actions
|
||||
assert len(result.actions) == 2
|
||||
# Check for SAVE_CLIENT_INFO action
|
||||
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
|
||||
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
|
||||
assert client_info_action.provider_id == "provider-id"
|
||||
assert client_info_action.tenant_id == "tenant-id"
|
||||
|
||||
# Check for SAVE_CODE_VERIFIER action
|
||||
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
|
||||
assert verifier_action.data == {"code_verifier": "code-verifier"}
|
||||
assert verifier_action.provider_id == "provider-id"
|
||||
assert verifier_action.tenant_id == "tenant-id"
|
||||
|
||||
# Verify calls
|
||||
mock_register.assert_called_once()
|
||||
mock_service.save_oauth_data.assert_any_call(
|
||||
"provider-id",
|
||||
"tenant-id",
|
||||
{"client_information": mock_register.return_value.model_dump()},
|
||||
"client_info",
|
||||
)
|
||||
mock_service.save_oauth_data.assert_any_call(
|
||||
"provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier"
|
||||
)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||
|
|
@ -637,12 +645,18 @@ class TestAuthOrchestration:
|
|||
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
|
||||
mock_exchange.return_value = tokens
|
||||
|
||||
result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key")
|
||||
result = auth(mock_provider, authorization_code="auth-code", state_param="state-key")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
# auth() now returns AuthResult, not a dict
|
||||
assert isinstance(result, AuthResult)
|
||||
assert result.response == {"result": "success"}
|
||||
|
||||
# Verify token save
|
||||
mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens")
|
||||
# Verify that the result contains the correct action
|
||||
assert len(result.actions) == 1
|
||||
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||
assert result.actions[0].data == tokens.model_dump()
|
||||
assert result.actions[0].provider_id == "provider-id"
|
||||
assert result.actions[0].tenant_id == "tenant-id"
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
||||
|
|
@ -658,7 +672,7 @@ class TestAuthOrchestration:
|
|||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
auth(mock_provider, mock_service, authorization_code="auth-code")
|
||||
auth(mock_provider, authorization_code="auth-code")
|
||||
|
||||
assert "State parameter is required" in str(exc_info.value)
|
||||
|
||||
|
|
@ -691,15 +705,21 @@ class TestAuthOrchestration:
|
|||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
|
||||
result = auth(mock_provider, mock_service)
|
||||
result = auth(mock_provider)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
# auth() now returns AuthResult
|
||||
assert isinstance(result, AuthResult)
|
||||
assert result.response == {"result": "success"}
|
||||
|
||||
# Verify that the result contains the correct action
|
||||
assert len(result.actions) == 1
|
||||
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||
assert result.actions[0].data == new_tokens.model_dump()
|
||||
assert result.actions[0].provider_id == "provider-id"
|
||||
assert result.actions[0].tenant_id == "tenant-id"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
mock_service.save_oauth_data.assert_called_with(
|
||||
"provider-id", "tenant-id", new_tokens.model_dump(), "tokens"
|
||||
)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
||||
|
|
@ -715,6 +735,6 @@ class TestAuthOrchestration:
|
|||
mock_provider.retrieve_client_information.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
auth(mock_provider, mock_service, authorization_code="auth-code")
|
||||
auth(mock_provider, authorization_code="auth-code")
|
||||
|
||||
assert "Existing OAuth client information is required" in str(exc_info.value)
|
||||
|
|
|
|||
|
|
@ -1,420 +0,0 @@
|
|||
"""Unit tests for MCP auth client with retry logic."""
|
||||
|
||||
from types import TracebackType
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
|
||||
|
||||
|
||||
class TestMCPClientWithAuthRetry:
|
||||
"""Test suite for MCPClientWithAuthRetry."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(self):
|
||||
"""Create a mock provider entity."""
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.id = "test-provider-id"
|
||||
provider.tenant_id = "test-tenant-id"
|
||||
provider.retrieve_tokens.return_value = Mock(
|
||||
access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
return provider
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_service(self):
|
||||
"""Create a mock MCP service."""
|
||||
service = Mock()
|
||||
service.get_provider_entity.return_value = Mock(
|
||||
retrieve_tokens=lambda: Mock(
|
||||
access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
)
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def auth_callback(self):
|
||||
"""Create a mock auth callback."""
|
||||
return Mock()
|
||||
|
||||
def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test client initialization."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
headers={"Authorization": "Bearer test"},
|
||||
timeout=30.0,
|
||||
sse_read_timeout=60.0,
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
authorization_code="test-auth-code",
|
||||
by_server_id=True,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
assert client.server_url == "http://test.example.com"
|
||||
assert client.headers == {"Authorization": "Bearer test"}
|
||||
assert client.timeout == 30.0
|
||||
assert client.sse_read_timeout == 60.0
|
||||
assert client.provider_entity == mock_provider_entity
|
||||
assert client.auth_callback == auth_callback
|
||||
assert client.authorization_code == "test-auth-code"
|
||||
assert client.by_server_id is True
|
||||
assert client.mcp_service == mock_mcp_service
|
||||
assert client._has_retried is False
|
||||
# In inheritance design, we don't have _client attribute
|
||||
assert hasattr(client, "_session") # Inherited from MCPClient
|
||||
|
||||
def test_inheritance_structure(self):
|
||||
"""Test that MCPClientWithAuthRetry properly inherits from MCPClient."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
headers={"Authorization": "Bearer test"},
|
||||
)
|
||||
|
||||
# Verify inheritance
|
||||
assert isinstance(client, MCPClient)
|
||||
|
||||
# Verify inherited attributes are accessible
|
||||
assert hasattr(client, "server_url")
|
||||
assert hasattr(client, "headers")
|
||||
assert hasattr(client, "_session")
|
||||
assert hasattr(client, "_exit_stack")
|
||||
assert hasattr(client, "_initialized")
|
||||
|
||||
def test_handle_auth_error_no_retry_components(self):
|
||||
"""Test auth error handling when retry components are missing."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
error = MCPAuthError("Auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
client._handle_auth_error(error)
|
||||
|
||||
assert exc_info.value == error
|
||||
|
||||
def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test auth error handling when already retried."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
client._has_retried = True
|
||||
error = MCPAuthError("Auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
client._handle_auth_error(error)
|
||||
|
||||
assert exc_info.value == error
|
||||
auth_callback.assert_not_called()
|
||||
|
||||
def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test successful auth refresh on error."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
authorization_code="test-code",
|
||||
by_server_id=True,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure mocks
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.id = "test-provider-id"
|
||||
new_provider.tenant_id = "test-tenant-id"
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
error = MCPAuthError("Auth failed")
|
||||
client._handle_auth_error(error)
|
||||
|
||||
# Verify auth flow
|
||||
auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code")
|
||||
mock_mcp_service.get_provider_entity.assert_called_once_with(
|
||||
"test-provider-id", "test-tenant-id", by_server_id=True
|
||||
)
|
||||
assert client.headers["Authorization"] == "Bearer new-token"
|
||||
assert client.authorization_code is None # Should be cleared after use
|
||||
assert client._has_retried is True
|
||||
|
||||
def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test auth refresh failure."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
auth_callback.side_effect = Exception("Auth callback failed")
|
||||
|
||||
error = MCPAuthError("Original auth failed")
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
client._handle_auth_error(error)
|
||||
|
||||
assert "Authentication retry failed" in str(exc_info.value)
|
||||
|
||||
def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test auth refresh when no token is received."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure mock to return no token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = None
|
||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
error = MCPAuthError("Auth failed")
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
client._handle_auth_error(error)
|
||||
|
||||
assert "no token received" in str(exc_info.value)
|
||||
|
||||
def test_execute_with_retry_success(self):
|
||||
"""Test successful execution without retry."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
mock_func = Mock(return_value="success")
|
||||
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
|
||||
|
||||
assert result == "success"
|
||||
mock_func.assert_called_once_with("arg1", kwarg1="value1")
|
||||
assert client._has_retried is False
|
||||
|
||||
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test execution with auth error followed by successful retry."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
# Mock function that fails first, then succeeds
|
||||
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
|
||||
|
||||
# Mock the exit stack and session cleanup
|
||||
with (
|
||||
patch.object(client, "_exit_stack") as mock_exit_stack,
|
||||
patch.object(client, "_session") as mock_session,
|
||||
patch.object(client, "_initialize") as mock_initialize,
|
||||
):
|
||||
client._initialized = True
|
||||
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
mock_func.assert_called_with("arg1", kwarg1="value1")
|
||||
auth_callback.assert_called_once()
|
||||
mock_exit_stack.close.assert_called_once()
|
||||
mock_initialize.assert_called_once()
|
||||
assert client._has_retried is False # Reset after completion
|
||||
|
||||
def test_execute_with_retry_non_auth_error(self):
|
||||
"""Test execution with non-auth error (no retry)."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
mock_func = Mock(side_effect=ValueError("Some other error"))
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
client._execute_with_retry(mock_func)
|
||||
|
||||
assert str(exc_info.value) == "Some other error"
|
||||
mock_func.assert_called_once()
|
||||
|
||||
def test_context_manager_enter(self):
|
||||
"""Test context manager enter."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
with patch.object(client, "_initialize") as mock_initialize:
|
||||
result = client.__enter__()
|
||||
|
||||
assert result == client
|
||||
assert client._initialized is True
|
||||
mock_initialize.assert_called_once()
|
||||
|
||||
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test context manager enter with auth error and retry."""
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Mock parent class __enter__ to raise auth error first, then succeed
|
||||
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
|
||||
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
|
||||
|
||||
result = client.__enter__()
|
||||
|
||||
assert result == client
|
||||
assert mock_parent_enter.call_count == 2
|
||||
auth_callback.assert_called_once()
|
||||
|
||||
def test_context_manager_exit(self):
|
||||
"""Test context manager exit."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
with patch.object(client, "cleanup") as mock_cleanup:
|
||||
exc_type: type[BaseException] | None = None
|
||||
exc_val: BaseException | None = None
|
||||
exc_tb: TracebackType | None = None
|
||||
client.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
def test_list_tools_not_initialized(self):
|
||||
"""Test list_tools when client not initialized."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
client.list_tools()
|
||||
|
||||
assert "Session not initialized" in str(exc_info.value)
|
||||
|
||||
def test_list_tools_success(self):
|
||||
"""Test successful list_tools call."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
expected_tools = [
|
||||
Tool(
|
||||
name="test-tool",
|
||||
description="A test tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
annotations=ToolAnnotations(title="Test Tool"),
|
||||
)
|
||||
]
|
||||
|
||||
# Mock the parent class list_tools method
|
||||
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
|
||||
result = client.list_tools()
|
||||
assert result == expected_tools
|
||||
|
||||
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test list_tools with auth retry."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
|
||||
|
||||
# Mock parent class list_tools to raise auth error first, then succeed
|
||||
with patch.object(MCPClient, "list_tools") as mock_list_tools:
|
||||
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
|
||||
|
||||
result = client.list_tools()
|
||||
|
||||
assert result == expected_tools
|
||||
assert mock_list_tools.call_count == 2
|
||||
auth_callback.assert_called_once()
|
||||
|
||||
def test_invoke_tool_not_initialized(self):
|
||||
"""Test invoke_tool when client not initialized."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
client.invoke_tool("test-tool", {"arg": "value"})
|
||||
|
||||
assert "Session not initialized" in str(exc_info.value)
|
||||
|
||||
def test_invoke_tool_success(self):
|
||||
"""Test successful invoke_tool call."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
expected_result = CallToolResult(
|
||||
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
|
||||
)
|
||||
|
||||
# Mock the parent class invoke_tool method
|
||||
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
|
||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
||||
|
||||
assert result == expected_result
|
||||
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
|
||||
|
||||
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
||||
"""Test invoke_tool with auth retry."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
provider_entity=mock_provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=mock_mcp_service,
|
||||
)
|
||||
|
||||
# Configure new provider with token
|
||||
new_provider = Mock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = Mock(
|
||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
||||
)
|
||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
|
||||
|
||||
# Mock parent class invoke_tool to raise auth error first, then succeed
|
||||
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
|
||||
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
|
||||
|
||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
||||
|
||||
assert result == expected_result
|
||||
assert mock_invoke_tool.call_count == 2
|
||||
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
|
||||
auth_callback.assert_called_once()
|
||||
|
||||
def test_cleanup(self):
|
||||
"""Test cleanup method."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
# Mock the parent class cleanup method
|
||||
with patch.object(MCPClient, "cleanup") as mock_cleanup:
|
||||
client.cleanup()
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
def test_cleanup_no_client(self):
|
||||
"""Test cleanup when no client exists."""
|
||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
||||
|
||||
# Should not raise
|
||||
client.cleanup()
|
||||
|
||||
# Since MCPClientWithAuthRetry inherits from MCPClient,
|
||||
# it doesn't have a _client attribute. The test should just
|
||||
# verify that cleanup can be called without error.
|
||||
assert not hasattr(client, "_client")
|
||||
Loading…
Reference in New Issue