mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 10:16:40 +08:00
refactor(mcp): clean the auth code
This commit is contained in:
parent
8cf4a0d3ad
commit
ffd3a461f6
@ -30,7 +30,7 @@ from models.provider_ids import ToolProviderID
|
|||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
|
||||||
from services.tools.tool_labels_service import ToolLabelsService
|
from services.tools.tool_labels_service import ToolLabelsService
|
||||||
from services.tools.tools_manage_service import ToolCommonService
|
from services.tools.tools_manage_service import ToolCommonService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
@ -897,10 +897,6 @@ class ToolProviderMCPApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user, tenant_id = current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# Validate server URL
|
|
||||||
if not is_valid_url(args["server_url"]):
|
|
||||||
raise ValueError("Server URL is not valid.")
|
|
||||||
|
|
||||||
# Parse and validate models
|
# Parse and validate models
|
||||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
@ -941,15 +937,21 @@ class ToolProviderMCPApi(Resource):
|
|||||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not is_valid_url(args["server_url"]):
|
|
||||||
if "[__HIDDEN__]" in args["server_url"]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError("Server URL is not valid.")
|
|
||||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||||
|
validation_result = None
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
validation_result = service.validate_server_url_change(
|
||||||
|
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# No need to check for errors here, exceptions will be raised directly
|
||||||
|
|
||||||
|
# Step 2: Perform database update in a transaction
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.update_provider(
|
service.update_provider(
|
||||||
@ -964,6 +966,7 @@ class ToolProviderMCPApi(Resource):
|
|||||||
headers=args["headers"],
|
headers=args["headers"],
|
||||||
configuration=configuration,
|
configuration=configuration,
|
||||||
authentication=authentication,
|
authentication=authentication,
|
||||||
|
validation_result=validation_result,
|
||||||
)
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@ -998,47 +1001,49 @@ class ToolMCPAuthApi(Resource):
|
|||||||
provider_id = args["provider_id"]
|
provider_id = args["provider_id"]
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session, session.begin():
|
||||||
with session.begin():
|
service = MCPToolManageService(session=session)
|
||||||
service = MCPToolManageService(session=session)
|
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
if not db_provider:
|
||||||
if not db_provider:
|
raise ValueError("provider not found")
|
||||||
raise ValueError("provider not found")
|
|
||||||
|
|
||||||
# Convert to entity
|
# Convert to entity
|
||||||
provider_entity = db_provider.to_entity()
|
provider_entity = db_provider.to_entity()
|
||||||
server_url = provider_entity.decrypt_server_url()
|
server_url = provider_entity.decrypt_server_url()
|
||||||
headers = provider_entity.decrypt_authentication()
|
headers = provider_entity.decrypt_authentication()
|
||||||
|
|
||||||
# Try to connect without active transaction
|
# Try to connect without active transaction
|
||||||
|
try:
|
||||||
|
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||||
|
with MCPClient(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=provider_entity.timeout,
|
||||||
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
|
):
|
||||||
|
# Create new transaction for update
|
||||||
|
with session.begin():
|
||||||
|
service.update_provider_credentials(
|
||||||
|
provider=db_provider,
|
||||||
|
credentials=provider_entity.credentials,
|
||||||
|
authed=True,
|
||||||
|
)
|
||||||
|
return {"result": "success"}
|
||||||
|
except MCPAuthError as e:
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
try:
|
try:
|
||||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||||
with MCPClient(
|
with session.begin():
|
||||||
server_url=server_url,
|
response = service.execute_auth_actions(auth_result)
|
||||||
headers=headers,
|
return response
|
||||||
timeout=provider_entity.timeout,
|
except MCPRefreshTokenError as e:
|
||||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
|
||||||
):
|
|
||||||
# Create new transaction for update
|
|
||||||
with session.begin():
|
|
||||||
service.update_provider_credentials(
|
|
||||||
provider=db_provider,
|
|
||||||
credentials=provider_entity.credentials,
|
|
||||||
authed=True,
|
|
||||||
)
|
|
||||||
return {"result": "success"}
|
|
||||||
except MCPAuthError as e:
|
|
||||||
service = MCPToolManageService(session=session)
|
|
||||||
try:
|
|
||||||
return auth(provider_entity, service, args.get("authorization_code"))
|
|
||||||
except MCPRefreshTokenError as e:
|
|
||||||
with session.begin():
|
|
||||||
service.clear_provider_credentials(provider=db_provider)
|
|
||||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
|
||||||
except MCPError as e:
|
|
||||||
with session.begin():
|
with session.begin():
|
||||||
service.clear_provider_credentials(provider=db_provider)
|
service.clear_provider_credentials(provider=db_provider)
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||||
|
except MCPError as e:
|
||||||
|
with session.begin():
|
||||||
|
service.clear_provider_credentials(provider=db_provider)
|
||||||
|
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||||
@ -1048,7 +1053,7 @@ class ToolMCPDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||||
@ -1062,7 +1067,7 @@ class ToolMCPListAllApi(Resource):
|
|||||||
def get(self):
|
def get(self):
|
||||||
_, tenant_id = current_account_with_tenant()
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
tools = service.list_providers(tenant_id=tenant_id)
|
tools = service.list_providers(tenant_id=tenant_id)
|
||||||
|
|
||||||
@ -1100,6 +1105,11 @@ class ToolMCPCallbackApi(Resource):
|
|||||||
# Create service instance for handle_callback
|
# Create service instance for handle_callback
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
mcp_service = MCPToolManageService(session=session)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
handle_callback(state_key, authorization_code, mcp_service)
|
# handle_callback now returns state data and tokens
|
||||||
|
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||||
|
# Save tokens using the service layer
|
||||||
|
mcp_service.save_oauth_data(
|
||||||
|
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
|
||||||
|
)
|
||||||
|
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
|||||||
@ -4,14 +4,14 @@ import json
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
|
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
|
||||||
from core.mcp.error import MCPRefreshTokenError
|
from core.mcp.error import MCPRefreshTokenError
|
||||||
from core.mcp.types import (
|
from core.mcp.types import (
|
||||||
LATEST_PROTOCOL_VERSION,
|
LATEST_PROTOCOL_VERSION,
|
||||||
@ -23,23 +23,10 @@ from core.mcp.types import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|
||||||
|
|
||||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||||
|
|
||||||
|
|
||||||
class OAuthCallbackState(BaseModel):
|
|
||||||
provider_id: str
|
|
||||||
tenant_id: str
|
|
||||||
server_url: str
|
|
||||||
metadata: OAuthMetadata | None = None
|
|
||||||
client_information: OAuthClientInformation
|
|
||||||
code_verifier: str
|
|
||||||
redirect_uri: str
|
|
||||||
|
|
||||||
|
|
||||||
def generate_pkce_challenge() -> tuple[str, str]:
|
def generate_pkce_challenge() -> tuple[str, str]:
|
||||||
"""Generate PKCE challenge and verifier."""
|
"""Generate PKCE challenge and verifier."""
|
||||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||||
@ -86,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
|||||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
|
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
|
||||||
"""Handle the callback from the OAuth provider."""
|
"""
|
||||||
|
Handle the callback from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (callback_state, tokens) that can be used by the caller to save data.
|
||||||
|
"""
|
||||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||||
full_state_data = _retrieve_redis_state(state_key)
|
full_state_data = _retrieve_redis_state(state_key)
|
||||||
|
|
||||||
@ -100,10 +92,7 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
|
|||||||
full_state_data.redirect_uri,
|
full_state_data.redirect_uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save tokens using the service layer
|
return full_state_data, tokens
|
||||||
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
|
|
||||||
|
|
||||||
return full_state_data
|
|
||||||
|
|
||||||
|
|
||||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||||
@ -361,11 +350,24 @@ def register_client(
|
|||||||
|
|
||||||
def auth(
|
def auth(
|
||||||
provider: MCPProviderEntity,
|
provider: MCPProviderEntity,
|
||||||
mcp_service: "MCPToolManageService",
|
|
||||||
authorization_code: str | None = None,
|
authorization_code: str | None = None,
|
||||||
state_param: str | None = None,
|
state_param: str | None = None,
|
||||||
) -> dict[str, str]:
|
) -> AuthResult:
|
||||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
"""
|
||||||
|
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||||
|
|
||||||
|
This function performs only network operations and returns actions that need
|
||||||
|
to be performed by the caller (such as saving data to database).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The MCP provider entity
|
||||||
|
authorization_code: Optional authorization code from OAuth callback
|
||||||
|
state_param: Optional state parameter from OAuth callback
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthResult containing actions to be performed and response data
|
||||||
|
"""
|
||||||
|
actions: list[AuthAction] = []
|
||||||
server_url = provider.decrypt_server_url()
|
server_url = provider.decrypt_server_url()
|
||||||
server_metadata = discover_oauth_metadata(server_url)
|
server_metadata = discover_oauth_metadata(server_url)
|
||||||
client_metadata = provider.client_metadata
|
client_metadata = provider.client_metadata
|
||||||
@ -407,9 +409,14 @@ def auth(
|
|||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
raise ValueError(f"Could not register OAuth client: {e}")
|
raise ValueError(f"Could not register OAuth client: {e}")
|
||||||
|
|
||||||
# Save client information using service layer
|
# Return action to save client information
|
||||||
mcp_service.save_oauth_data(
|
actions.append(
|
||||||
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_CLIENT_INFO,
|
||||||
|
data={"client_information": full_information.model_dump()},
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
client_information = full_information
|
client_information = full_information
|
||||||
@ -426,12 +433,20 @@ def auth(
|
|||||||
scope,
|
scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save tokens and grant type
|
# Return action to save tokens and grant type
|
||||||
token_data = tokens.model_dump()
|
token_data = tokens.model_dump()
|
||||||
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||||
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
|
|
||||||
|
|
||||||
return {"result": "success"}
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_TOKENS,
|
||||||
|
data=token_data,
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(actions=actions, response={"result": "success"})
|
||||||
except (RequestError, ValueError, KeyError) as e:
|
except (RequestError, ValueError, KeyError) as e:
|
||||||
# RequestError: HTTP request failed
|
# RequestError: HTTP request failed
|
||||||
# ValueError: Invalid response data
|
# ValueError: Invalid response data
|
||||||
@ -465,10 +480,17 @@ def auth(
|
|||||||
redirect_uri,
|
redirect_uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save tokens using service layer
|
# Return action to save tokens
|
||||||
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_TOKENS,
|
||||||
|
data=tokens.model_dump(),
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return {"result": "success"}
|
return AuthResult(actions=actions, response={"result": "success"})
|
||||||
|
|
||||||
provider_tokens = provider.retrieve_tokens()
|
provider_tokens = provider.retrieve_tokens()
|
||||||
|
|
||||||
@ -479,10 +501,17 @@ def auth(
|
|||||||
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save new tokens using service layer
|
# Return action to save new tokens
|
||||||
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_TOKENS,
|
||||||
|
data=new_tokens.model_dump(),
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return {"result": "success"}
|
return AuthResult(actions=actions, response={"result": "success"})
|
||||||
except (RequestError, ValueError, KeyError) as e:
|
except (RequestError, ValueError, KeyError) as e:
|
||||||
# RequestError: HTTP request failed
|
# RequestError: HTTP request failed
|
||||||
# ValueError: Invalid response data
|
# ValueError: Invalid response data
|
||||||
@ -499,7 +528,14 @@ def auth(
|
|||||||
tenant_id,
|
tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save code verifier using service layer
|
# Return action to save code verifier
|
||||||
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
|
actions.append(
|
||||||
|
AuthAction(
|
||||||
|
action_type=AuthActionType.SAVE_CODE_VERIFIER,
|
||||||
|
data={"code_verifier": code_verifier},
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return {"authorization_url": authorization_url}
|
return AuthResult(actions=actions, response={"authorization_url": authorization_url})
|
||||||
|
|||||||
@ -7,15 +7,15 @@ authentication failures and retries operations after refreshing tokens.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.entities.mcp_provider import MCPProviderEntity
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
from core.mcp.error import MCPAuthError
|
from core.mcp.error import MCPAuthError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.mcp.types import CallToolResult, Tool
|
from core.mcp.types import CallToolResult, Tool
|
||||||
|
from extensions.ext_database import db
|
||||||
if TYPE_CHECKING:
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -26,6 +26,9 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||||||
|
|
||||||
This class extends MCPClient and intercepts MCPAuthError exceptions
|
This class extends MCPClient and intercepts MCPAuthError exceptions
|
||||||
to refresh authentication before retrying failed operations.
|
to refresh authentication before retrying failed operations.
|
||||||
|
|
||||||
|
Note: This class uses lazy session creation - database sessions are only
|
||||||
|
created when authentication retry is actually needed, not on every request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -35,11 +38,8 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
sse_read_timeout: float | None = None,
|
sse_read_timeout: float | None = None,
|
||||||
provider_entity: MCPProviderEntity | None = None,
|
provider_entity: MCPProviderEntity | None = None,
|
||||||
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
|
|
||||||
| None = None,
|
|
||||||
authorization_code: str | None = None,
|
authorization_code: str | None = None,
|
||||||
by_server_id: bool = False,
|
by_server_id: bool = False,
|
||||||
mcp_service: Optional["MCPToolManageService"] = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the MCP client with auth retry capability.
|
Initialize the MCP client with auth retry capability.
|
||||||
@ -50,31 +50,30 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||||||
timeout: Request timeout
|
timeout: Request timeout
|
||||||
sse_read_timeout: SSE read timeout
|
sse_read_timeout: SSE read timeout
|
||||||
provider_entity: Provider entity for authentication
|
provider_entity: Provider entity for authentication
|
||||||
auth_callback: Authentication callback function
|
|
||||||
authorization_code: Optional authorization code for initial auth
|
authorization_code: Optional authorization code for initial auth
|
||||||
by_server_id: Whether to look up provider by server ID
|
by_server_id: Whether to look up provider by server ID
|
||||||
mcp_service: MCP service instance
|
|
||||||
"""
|
"""
|
||||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||||
|
|
||||||
self.provider_entity = provider_entity
|
self.provider_entity = provider_entity
|
||||||
self.auth_callback = auth_callback
|
|
||||||
self.authorization_code = authorization_code
|
self.authorization_code = authorization_code
|
||||||
self.by_server_id = by_server_id
|
self.by_server_id = by_server_id
|
||||||
self.mcp_service = mcp_service
|
|
||||||
self._has_retried = False
|
self._has_retried = False
|
||||||
|
|
||||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||||
"""
|
"""
|
||||||
Handle authentication error by refreshing tokens.
|
Handle authentication error by refreshing tokens.
|
||||||
|
|
||||||
|
This method creates a short-lived database session only when authentication
|
||||||
|
retry is needed, minimizing database connection hold time.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error: The authentication error
|
error: The authentication error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MCPAuthError: If authentication fails or max retries reached
|
MCPAuthError: If authentication fails or max retries reached
|
||||||
"""
|
"""
|
||||||
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
|
if not self.provider_entity:
|
||||||
raise error
|
raise error
|
||||||
if self._has_retried:
|
if self._has_retried:
|
||||||
raise error
|
raise error
|
||||||
@ -82,13 +81,23 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||||||
self._has_retried = True
|
self._has_retried = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Perform authentication
|
# Create a temporary session only for auth retry
|
||||||
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
|
# This session is short-lived and only exists during the auth operation
|
||||||
|
|
||||||
# Retrieve new tokens
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
self.provider_entity = self.mcp_service.get_provider_entity(
|
|
||||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
with Session(db.engine) as session, session.begin():
|
||||||
)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
|
|
||||||
|
# Perform authentication using the service's auth method
|
||||||
|
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||||
|
|
||||||
|
# Retrieve new tokens
|
||||||
|
self.provider_entity = mcp_service.get_provider_entity(
|
||||||
|
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session is closed here, before we update headers
|
||||||
token = self.provider_entity.retrieve_tokens()
|
token = self.provider_entity.retrieve_tokens()
|
||||||
if not token:
|
if not token:
|
||||||
raise MCPAuthError("Authentication failed - no token received")
|
raise MCPAuthError("Authentication failed - no token received")
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.mcp.session.base_session import BaseSession
|
from core.mcp.session.base_session import BaseSession
|
||||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
|
||||||
|
|
||||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||||
|
|
||||||
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
|
|||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
lifespan_context: LifespanContextT
|
lifespan_context: LifespanContextT
|
||||||
|
|
||||||
|
|
||||||
|
class AuthActionType(StrEnum):
|
||||||
|
"""Types of actions that can be performed during auth flow."""
|
||||||
|
|
||||||
|
SAVE_CLIENT_INFO = "save_client_info"
|
||||||
|
SAVE_TOKENS = "save_tokens"
|
||||||
|
SAVE_CODE_VERIFIER = "save_code_verifier"
|
||||||
|
START_AUTHORIZATION = "start_authorization"
|
||||||
|
SUCCESS = "success"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthAction(BaseModel):
|
||||||
|
"""Represents an action that needs to be performed as a result of auth flow."""
|
||||||
|
|
||||||
|
action_type: AuthActionType
|
||||||
|
data: dict[str, Any]
|
||||||
|
provider_id: str | None = None
|
||||||
|
tenant_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthResult(BaseModel):
|
||||||
|
"""Result of auth function containing actions to be performed and response data."""
|
||||||
|
|
||||||
|
actions: list[AuthAction]
|
||||||
|
response: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackState(BaseModel):
|
||||||
|
"""State data stored in Redis during OAuth callback flow."""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
tenant_id: str
|
||||||
|
server_url: str
|
||||||
|
metadata: OAuthMetadata | None = None
|
||||||
|
client_information: OAuthClientInformation
|
||||||
|
code_verifier: str
|
||||||
|
redirect_uri: str
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import json
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.mcp.auth.auth_flow import auth
|
|
||||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPConnectionError
|
from core.mcp.error import MCPConnectionError
|
||||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||||
@ -125,71 +124,39 @@ class MCPTool(Tool):
|
|||||||
headers = self.headers.copy() if self.headers else {}
|
headers = self.headers.copy() if self.headers else {}
|
||||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||||
|
|
||||||
# Get provider entity to access tokens
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
# Get MCP service from invoke parameters or create new one
|
from extensions.ext_database import db
|
||||||
provider_entity = None
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
mcp_service = None
|
|
||||||
|
|
||||||
# Check if mcp_service is passed in tool_parameters
|
# Step 1: Load provider entity and credentials in a short-lived session
|
||||||
if "_mcp_service" in tool_parameters:
|
# This minimizes database connection hold time
|
||||||
mcp_service = tool_parameters.pop("_mcp_service")
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
if mcp_service:
|
mcp_service = MCPToolManageService(session=session)
|
||||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||||
headers = provider_entity.decrypt_headers()
|
|
||||||
# Try to get existing token and add to headers
|
|
||||||
if not headers:
|
|
||||||
tokens = provider_entity.retrieve_tokens()
|
|
||||||
if tokens and tokens.access_token:
|
|
||||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
|
||||||
|
|
||||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
# Decrypt and prepare all credentials before closing session
|
||||||
try:
|
server_url = provider_entity.decrypt_server_url()
|
||||||
with MCPClientWithAuthRetry(
|
headers = provider_entity.decrypt_headers()
|
||||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
|
||||||
headers=headers,
|
|
||||||
timeout=self.timeout,
|
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
|
||||||
provider_entity=provider_entity,
|
|
||||||
auth_callback=auth if mcp_service else None,
|
|
||||||
mcp_service=mcp_service,
|
|
||||||
) as mcp_client:
|
|
||||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
|
||||||
except MCPConnectionError as e:
|
|
||||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
|
||||||
except (ValueError, TypeError, KeyError) as e:
|
|
||||||
# Catch specific exceptions that might occur during tool invocation
|
|
||||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
|
||||||
else:
|
|
||||||
# Fallback to creating service with database session
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
# Try to get existing token and add to headers
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
if not headers:
|
||||||
|
tokens = provider_entity.retrieve_tokens()
|
||||||
|
if tokens and tokens.access_token:
|
||||||
|
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||||
mcp_service = MCPToolManageService(session=session)
|
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
try:
|
||||||
headers = provider_entity.decrypt_headers()
|
with MCPClientWithAuthRetry(
|
||||||
# Try to get existing token and add to headers
|
server_url=server_url,
|
||||||
if not headers:
|
headers=headers,
|
||||||
tokens = provider_entity.retrieve_tokens()
|
timeout=self.timeout,
|
||||||
if tokens and tokens.access_token:
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
provider_entity=provider_entity,
|
||||||
|
) as mcp_client:
|
||||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||||
try:
|
except MCPConnectionError as e:
|
||||||
with MCPClientWithAuthRetry(
|
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
except Exception as e:
|
||||||
headers=headers,
|
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||||
timeout=self.timeout,
|
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
|
||||||
provider_entity=provider_entity,
|
|
||||||
auth_callback=auth if mcp_service else None,
|
|
||||||
mcp_service=mcp_service,
|
|
||||||
) as mcp_client:
|
|
||||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
|
||||||
except MCPConnectionError as e:
|
|
||||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import or_, select
|
from sqlalchemy import or_, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -12,6 +14,7 @@ from sqlalchemy.orm import Session
|
|||||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.mcp.auth.auth_flow import auth
|
||||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
@ -28,6 +31,38 @@ EMPTY_TOOLS_JSON = "[]"
|
|||||||
EMPTY_CREDENTIALS_JSON = "{}"
|
EMPTY_CREDENTIALS_JSON = "{}"
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthDataType(StrEnum):
|
||||||
|
"""Types of OAuth data that can be saved."""
|
||||||
|
|
||||||
|
TOKENS = "tokens"
|
||||||
|
CLIENT_INFO = "client_info"
|
||||||
|
CODE_VERIFIER = "code_verifier"
|
||||||
|
MIXED = "mixed"
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectResult(BaseModel):
|
||||||
|
"""Result of reconnecting to an MCP provider"""
|
||||||
|
|
||||||
|
authed: bool = Field(description="Whether the provider is authenticated")
|
||||||
|
tools: str = Field(description="JSON string of tool list")
|
||||||
|
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
|
||||||
|
|
||||||
|
|
||||||
|
class ServerUrlValidationResult(BaseModel):
|
||||||
|
"""Result of server URL validation check"""
|
||||||
|
|
||||||
|
needs_validation: bool
|
||||||
|
validation_passed: bool = False
|
||||||
|
reconnect_result: ReconnectResult | None = None
|
||||||
|
encrypted_server_url: str | None = None
|
||||||
|
server_url_hash: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_update_server_url(self) -> bool:
|
||||||
|
"""Check if server URL should be updated based on validation result"""
|
||||||
|
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||||
|
|
||||||
|
|
||||||
class MCPToolManageService:
|
class MCPToolManageService:
|
||||||
"""Service class for managing MCP tools and providers."""
|
"""Service class for managing MCP tools and providers."""
|
||||||
|
|
||||||
@ -91,6 +126,10 @@ class MCPToolManageService:
|
|||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
"""Create a new MCP provider."""
|
"""Create a new MCP provider."""
|
||||||
|
# Validate URL format
|
||||||
|
if not self._is_valid_url(server_url):
|
||||||
|
raise ValueError("Server URL is not valid.")
|
||||||
|
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
|
|
||||||
# Check for existing provider
|
# Check for existing provider
|
||||||
@ -99,13 +138,12 @@ class MCPToolManageService:
|
|||||||
# Encrypt sensitive data
|
# Encrypt sensitive data
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||||
if authentication is not None and authentication.client_id and authentication.client_secret:
|
encrypted_credentials = None
|
||||||
# Build the full credentials structure with encrypted client_id and client_secret
|
if authentication is not None and authentication.client_id:
|
||||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||||
authentication.client_id, authentication.client_secret, tenant_id
|
authentication.client_id, authentication.client_secret, tenant_id
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
encrypted_credentials = None
|
|
||||||
# Create provider
|
# Create provider
|
||||||
mcp_tool = MCPToolProvider(
|
mcp_tool = MCPToolProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@ -142,24 +180,39 @@ class MCPToolManageService:
|
|||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
configuration: MCPConfiguration,
|
configuration: MCPConfiguration,
|
||||||
authentication: MCPAuthentication | None = None,
|
authentication: MCPAuthentication | None = None,
|
||||||
|
validation_result: ServerUrlValidationResult | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update an MCP provider."""
|
"""
|
||||||
|
Update an MCP provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
validation_result: Pre-validation result from validate_server_url_change.
|
||||||
|
If provided and contains reconnect_result, it will be used
|
||||||
|
instead of performing network operations.
|
||||||
|
"""
|
||||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
reconnect_result = None
|
# Check for duplicate name (excluding current provider)
|
||||||
|
if name != mcp_provider.name:
|
||||||
|
stmt = select(MCPToolProvider).where(
|
||||||
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
|
MCPToolProvider.name == name,
|
||||||
|
MCPToolProvider.id != provider_id,
|
||||||
|
)
|
||||||
|
existing_provider = self._session.scalar(stmt)
|
||||||
|
if existing_provider:
|
||||||
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
|
|
||||||
|
# Get URL update data from validation result
|
||||||
encrypted_server_url = None
|
encrypted_server_url = None
|
||||||
server_url_hash = None
|
server_url_hash = None
|
||||||
|
reconnect_result = None
|
||||||
|
|
||||||
# Handle server URL update
|
if validation_result and validation_result.encrypted_server_url:
|
||||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
# Use all data from validation result
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
encrypted_server_url = validation_result.encrypted_server_url
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
server_url_hash = validation_result.server_url_hash
|
||||||
|
reconnect_result = validation_result.reconnect_result
|
||||||
if server_url_hash != mcp_provider.server_url_hash:
|
|
||||||
reconnect_result = self._reconnect_provider(
|
|
||||||
server_url=server_url,
|
|
||||||
provider=mcp_provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update basic fields
|
# Update basic fields
|
||||||
@ -169,63 +222,35 @@ class MCPToolManageService:
|
|||||||
mcp_provider.server_identifier = server_identifier
|
mcp_provider.server_identifier = server_identifier
|
||||||
|
|
||||||
# Update server URL if changed
|
# Update server URL if changed
|
||||||
if encrypted_server_url is not None and server_url_hash is not None:
|
if encrypted_server_url and server_url_hash:
|
||||||
mcp_provider.server_url = encrypted_server_url
|
mcp_provider.server_url = encrypted_server_url
|
||||||
mcp_provider.server_url_hash = server_url_hash
|
mcp_provider.server_url_hash = server_url_hash
|
||||||
|
|
||||||
if reconnect_result:
|
if reconnect_result:
|
||||||
mcp_provider.authed = reconnect_result["authed"]
|
mcp_provider.authed = reconnect_result.authed
|
||||||
mcp_provider.tools = reconnect_result["tools"]
|
mcp_provider.tools = reconnect_result.tools
|
||||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
|
||||||
|
|
||||||
# Update optional fields
|
# Update optional configuration fields
|
||||||
if configuration.timeout is not None:
|
self._update_optional_fields(mcp_provider, configuration)
|
||||||
mcp_provider.timeout = configuration.timeout
|
|
||||||
if configuration.sse_read_timeout is not None:
|
# Update headers if provided
|
||||||
mcp_provider.sse_read_timeout = configuration.sse_read_timeout
|
|
||||||
if headers is not None:
|
if headers is not None:
|
||||||
if headers:
|
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
|
||||||
# Build headers preserving unchanged masked values
|
|
||||||
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
|
||||||
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
|
|
||||||
mcp_provider.encrypted_headers = encrypted_headers_dict
|
|
||||||
else:
|
|
||||||
# Clear headers if empty dict passed
|
|
||||||
mcp_provider.encrypted_headers = None
|
|
||||||
|
|
||||||
# Update credentials if provided
|
# Update credentials if provided
|
||||||
if authentication is not None and authentication.client_id and authentication.client_secret:
|
if authentication and authentication.client_id:
|
||||||
# Merge with existing credentials to handle masked values
|
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||||
(
|
|
||||||
final_client_id,
|
|
||||||
final_client_secret,
|
|
||||||
) = self._merge_credentials_with_masked(
|
|
||||||
authentication.client_id, authentication.client_secret, mcp_provider
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build and encrypt new credentials
|
# Flush changes to database
|
||||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
self._session.flush()
|
||||||
final_client_id, final_client_secret, tenant_id
|
|
||||||
)
|
|
||||||
mcp_provider.encrypted_credentials = encrypted_credentials
|
|
||||||
|
|
||||||
self._session.commit()
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
self._session.rollback()
|
|
||||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
|
||||||
# Catch specific exceptions that might occur during update
|
|
||||||
# ValueError: invalid data provided
|
|
||||||
# AttributeError: missing required attributes
|
|
||||||
# TypeError: type conversion errors
|
|
||||||
self._session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||||
"""Delete an MCP provider."""
|
"""Delete an MCP provider."""
|
||||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
self._session.delete(mcp_tool)
|
self._session.delete(mcp_tool)
|
||||||
self._session.commit()
|
|
||||||
|
|
||||||
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||||
"""List all MCP providers for a tenant."""
|
"""List all MCP providers for a tenant."""
|
||||||
@ -241,8 +266,6 @@ class MCPToolManageService:
|
|||||||
|
|
||||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||||
"""List tools from remote MCP server."""
|
"""List tools from remote MCP server."""
|
||||||
from core.mcp.auth.auth_flow import auth
|
|
||||||
|
|
||||||
# Load provider and convert to entity
|
# Load provider and convert to entity
|
||||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
provider_entity = db_provider.to_entity()
|
provider_entity = db_provider.to_entity()
|
||||||
@ -257,9 +280,7 @@ class MCPToolManageService:
|
|||||||
# Retrieve tools from remote server
|
# Retrieve tools from remote server
|
||||||
server_url = provider_entity.decrypt_server_url()
|
server_url = provider_entity.decrypt_server_url()
|
||||||
try:
|
try:
|
||||||
tools = self._retrieve_remote_mcp_tools(
|
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||||
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
|
|
||||||
)
|
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||||
|
|
||||||
@ -305,9 +326,12 @@ class MCPToolManageService:
|
|||||||
if not authed:
|
if not authed:
|
||||||
provider.tools = EMPTY_TOOLS_JSON
|
provider.tools = EMPTY_TOOLS_JSON
|
||||||
|
|
||||||
self._session.commit()
|
# Flush changes to database
|
||||||
|
self._session.flush()
|
||||||
|
|
||||||
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
|
def save_oauth_data(
|
||||||
|
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Save OAuth-related data (tokens, client info, code verifier).
|
Save OAuth-related data (tokens, client info, code verifier).
|
||||||
|
|
||||||
@ -315,12 +339,14 @@ class MCPToolManageService:
|
|||||||
provider_id: Provider ID
|
provider_id: Provider ID
|
||||||
tenant_id: Tenant ID
|
tenant_id: Tenant ID
|
||||||
data: Data to save (tokens, client info, or code verifier)
|
data: Data to save (tokens, client info, or code verifier)
|
||||||
data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed')
|
data_type: Type of OAuth data to save
|
||||||
"""
|
"""
|
||||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
# Determine if this makes the provider authenticated
|
# Determine if this makes the provider authenticated
|
||||||
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
|
authed = (
|
||||||
|
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||||
|
)
|
||||||
|
|
||||||
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
|
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
|
||||||
|
|
||||||
@ -330,7 +356,6 @@ class MCPToolManageService:
|
|||||||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||||
provider.updated_at = datetime.now()
|
provider.updated_at = datetime.now()
|
||||||
provider.authed = False
|
provider.authed = False
|
||||||
self._session.commit()
|
|
||||||
|
|
||||||
# ========== Private Helper Methods ==========
|
# ========== Private Helper Methods ==========
|
||||||
|
|
||||||
@ -406,41 +431,123 @@ class MCPToolManageService:
|
|||||||
server_url: str,
|
server_url: str,
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
provider_entity: MCPProviderEntity,
|
provider_entity: MCPProviderEntity,
|
||||||
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", str | None], dict[str, str]],
|
|
||||||
):
|
):
|
||||||
"""Retrieve tools from remote MCP server."""
|
"""Retrieve tools from remote MCP server."""
|
||||||
with MCPClientWithAuthRetry(
|
with MCPClientWithAuthRetry(
|
||||||
server_url,
|
server_url=server_url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=provider_entity.timeout,
|
timeout=provider_entity.timeout,
|
||||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
provider_entity=provider_entity,
|
provider_entity=provider_entity,
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=self,
|
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
return mcp_client.list_tools()
|
return mcp_client.list_tools()
|
||||||
|
|
||||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
|
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
|
||||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
"""
|
||||||
from core.mcp.auth.auth_flow import auth
|
Execute the actions returned by the auth function.
|
||||||
|
|
||||||
|
This method processes the AuthResult and performs the necessary database operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_result: The result from the auth function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response from the auth result
|
||||||
|
"""
|
||||||
|
from core.mcp.entities import AuthAction, AuthActionType
|
||||||
|
|
||||||
|
action: AuthAction
|
||||||
|
for action in auth_result.actions:
|
||||||
|
if action.provider_id is None or action.tenant_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
|
||||||
|
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
|
||||||
|
elif action.action_type == AuthActionType.SAVE_TOKENS:
|
||||||
|
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
|
||||||
|
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
|
||||||
|
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
|
||||||
|
|
||||||
|
return auth_result.response
|
||||||
|
|
||||||
|
def auth_with_actions(
|
||||||
|
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Perform authentication and execute all resulting actions.
|
||||||
|
|
||||||
|
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_entity: The MCP provider entity
|
||||||
|
authorization_code: Optional authorization code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response dictionary from auth result
|
||||||
|
"""
|
||||||
|
auth_result = auth(provider_entity, authorization_code)
|
||||||
|
return self.execute_auth_actions(auth_result)
|
||||||
|
|
||||||
|
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||||
|
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||||
provider_entity = provider.to_entity()
|
provider_entity = provider.to_entity()
|
||||||
headers = provider_entity.headers
|
headers = provider_entity.headers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tools = self._retrieve_remote_mcp_tools(
|
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||||
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
|
return ReconnectResult(
|
||||||
|
authed=True,
|
||||||
|
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||||
|
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||||
)
|
)
|
||||||
return {
|
|
||||||
"authed": True,
|
|
||||||
"tools": json.dumps([tool.model_dump() for tool in tools]),
|
|
||||||
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
|
|
||||||
}
|
|
||||||
except MCPAuthError:
|
except MCPAuthError:
|
||||||
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
|
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||||
|
|
||||||
|
def validate_server_url_change(
|
||||||
|
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||||
|
) -> ServerUrlValidationResult:
|
||||||
|
"""
|
||||||
|
Validate server URL change by attempting to connect to the new server.
|
||||||
|
This method should be called BEFORE update_provider to perform network operations
|
||||||
|
outside of the database transaction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||||
|
"""
|
||||||
|
# Handle hidden/unchanged URL
|
||||||
|
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
|
||||||
|
return ServerUrlValidationResult(needs_validation=False)
|
||||||
|
|
||||||
|
# Validate URL format
|
||||||
|
if not self._is_valid_url(new_server_url):
|
||||||
|
raise ValueError("Server URL is not valid.")
|
||||||
|
|
||||||
|
# Always encrypt and hash the URL
|
||||||
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||||
|
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||||
|
|
||||||
|
# Get current provider
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
|
||||||
|
# Check if URL is actually different
|
||||||
|
if new_server_url_hash == provider.server_url_hash:
|
||||||
|
# URL hasn't changed, but still return the encrypted data
|
||||||
|
return ServerUrlValidationResult(
|
||||||
|
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform validation by attempting to connect
|
||||||
|
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||||
|
return ServerUrlValidationResult(
|
||||||
|
needs_validation=True,
|
||||||
|
validation_passed=True,
|
||||||
|
reconnect_result=reconnect_result,
|
||||||
|
encrypted_server_url=encrypted_server_url,
|
||||||
|
server_url_hash=new_server_url_hash,
|
||||||
|
)
|
||||||
|
|
||||||
def _build_tool_provider_response(
|
def _build_tool_provider_response(
|
||||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
@ -466,6 +573,45 @@ class MCPToolManageService:
|
|||||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _is_valid_url(self, url: str) -> bool:
|
||||||
|
"""Validate URL format."""
|
||||||
|
if not url:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
|
||||||
|
"""Update optional configuration fields using setattr for cleaner code."""
|
||||||
|
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
|
||||||
|
|
||||||
|
for field, value in field_mapping.items():
|
||||||
|
if value is not None:
|
||||||
|
setattr(mcp_provider, field, value)
|
||||||
|
|
||||||
|
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
|
||||||
|
"""Process headers update, handling empty dict to clear headers."""
|
||||||
|
if not headers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Merge with existing headers to preserve masked values
|
||||||
|
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||||
|
return self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||||
|
|
||||||
|
def _process_credentials(
|
||||||
|
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
|
||||||
|
) -> str:
|
||||||
|
"""Process credentials update, handling masked values."""
|
||||||
|
# Merge with existing credentials
|
||||||
|
final_client_id, final_client_secret = self._merge_credentials_with_masked(
|
||||||
|
authentication.client_id, authentication.client_secret, mcp_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build and encrypt
|
||||||
|
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
|
||||||
|
|
||||||
def _merge_headers_with_masked(
|
def _merge_headers_with_masked(
|
||||||
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
|
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
@ -530,12 +676,12 @@ class MCPToolManageService:
|
|||||||
# Create a flat structure with all credential data
|
# Create a flat structure with all credential data
|
||||||
credentials_data = {
|
credentials_data = {
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"encrypted_client_secret": client_secret,
|
|
||||||
"client_name": CLIENT_NAME,
|
"client_name": CLIENT_NAME,
|
||||||
"is_dynamic_registration": False,
|
"is_dynamic_registration": False,
|
||||||
}
|
}
|
||||||
|
secret_fields = []
|
||||||
# Only client_id and client_secret need encryption
|
if client_secret is not None:
|
||||||
secret_fields = ["encrypted_client_secret"] if client_secret else []
|
credentials_data["encrypted_client_secret"] = encrypter.encrypt_token(tenant_id, client_secret)
|
||||||
|
secret_fields = ["encrypted_client_secret"]
|
||||||
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||||
return json.dumps({"client_information": client_info})
|
return json.dumps({"client_information": client_info})
|
||||||
|
|||||||
@ -1108,75 +1108,6 @@ class TestMCPToolManageService:
|
|||||||
assert icon_data["content"] == "🚀"
|
assert icon_data["content"] == "🚀"
|
||||||
assert icon_data["background"] == "#4ECDC4"
|
assert icon_data["background"] == "#4ECDC4"
|
||||||
|
|
||||||
def test_update_mcp_provider_with_server_url_change(
|
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test successful update of MCP provider with server URL change.
|
|
||||||
|
|
||||||
This test verifies:
|
|
||||||
- Proper handling of server URL changes
|
|
||||||
- Correct reconnection logic
|
|
||||||
- Database state updates
|
|
||||||
- External service integration
|
|
||||||
"""
|
|
||||||
# Arrange: Create test data
|
|
||||||
fake = Faker()
|
|
||||||
account, tenant = self._create_test_account_and_tenant(
|
|
||||||
db_session_with_containers, mock_external_service_dependencies
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create MCP provider
|
|
||||||
mcp_provider = self._create_test_mcp_provider(
|
|
||||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
|
||||||
)
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Mock the reconnection method
|
|
||||||
with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect:
|
|
||||||
mock_reconnect.return_value = {
|
|
||||||
"authed": True,
|
|
||||||
"tools": '[{"name": "test_tool"}]',
|
|
||||||
"encrypted_credentials": "{}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Act: Execute the method under test
|
|
||||||
from core.entities.mcp_provider import MCPConfiguration
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
service.update_provider(
|
|
||||||
tenant_id=tenant.id,
|
|
||||||
provider_id=mcp_provider.id,
|
|
||||||
name="Updated MCP Provider",
|
|
||||||
server_url="https://new-example.com/mcp",
|
|
||||||
icon="🚀",
|
|
||||||
icon_type="emoji",
|
|
||||||
icon_background="#4ECDC4",
|
|
||||||
server_identifier="updated_identifier_123",
|
|
||||||
configuration=MCPConfiguration(
|
|
||||||
timeout=45.0,
|
|
||||||
sse_read_timeout=400.0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
|
||||||
db.session.refresh(mcp_provider)
|
|
||||||
assert mcp_provider.name == "Updated MCP Provider"
|
|
||||||
assert mcp_provider.server_identifier == "updated_identifier_123"
|
|
||||||
assert mcp_provider.timeout == 45.0
|
|
||||||
assert mcp_provider.sse_read_timeout == 400.0
|
|
||||||
assert mcp_provider.updated_at is not None
|
|
||||||
|
|
||||||
# Verify reconnection was called
|
|
||||||
mock_reconnect.assert_called_once_with(
|
|
||||||
server_url="https://new-example.com/mcp",
|
|
||||||
provider=mcp_provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
Test error handling when updating MCP provider with duplicate name.
|
Test error handling when updating MCP provider with duplicate name.
|
||||||
@ -1387,14 +1318,14 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["authed"] is True
|
assert result.authed is True
|
||||||
assert result["tools"] is not None
|
assert result.tools is not None
|
||||||
assert result["encrypted_credentials"] == "{}"
|
assert result.encrypted_credentials == "{}"
|
||||||
|
|
||||||
# Verify tools were properly serialized
|
# Verify tools were properly serialized
|
||||||
import json
|
import json
|
||||||
|
|
||||||
tools_data = json.loads(result["tools"])
|
tools_data = json.loads(result.tools)
|
||||||
assert len(tools_data) == 2
|
assert len(tools_data) == 2
|
||||||
assert tools_data[0]["name"] == "test_tool_1"
|
assert tools_data[0]["name"] == "test_tool_1"
|
||||||
assert tools_data[1]["name"] == "test_tool_2"
|
assert tools_data[1]["name"] == "test_tool_2"
|
||||||
@ -1441,9 +1372,9 @@ class TestMCPToolManageService:
|
|||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["authed"] is False
|
assert result.authed is False
|
||||||
assert result["tools"] == "[]"
|
assert result.tools == "[]"
|
||||||
assert result["encrypted_credentials"] == "{}"
|
assert result.encrypted_credentials == "{}"
|
||||||
|
|
||||||
def test_re_connect_mcp_provider_connection_error(
|
def test_re_connect_mcp_provider_connection_error(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from core.mcp.auth.auth_flow import (
|
|||||||
register_client,
|
register_client,
|
||||||
start_authorization,
|
start_authorization,
|
||||||
)
|
)
|
||||||
|
from core.mcp.entities import AuthActionType, AuthResult
|
||||||
from core.mcp.types import (
|
from core.mcp.types import (
|
||||||
OAuthClientInformation,
|
OAuthClientInformation,
|
||||||
OAuthClientInformationFull,
|
OAuthClientInformationFull,
|
||||||
@ -527,9 +528,10 @@ class TestCallbackHandling:
|
|||||||
# Setup service
|
# Setup service
|
||||||
mock_service = Mock()
|
mock_service = Mock()
|
||||||
|
|
||||||
result = handle_callback("state-key", "auth-code", mock_service)
|
state_result, tokens_result = handle_callback("state-key", "auth-code")
|
||||||
|
|
||||||
assert result == state_data
|
assert state_result == state_data
|
||||||
|
assert tokens_result == tokens
|
||||||
|
|
||||||
# Verify calls
|
# Verify calls
|
||||||
mock_retrieve_state.assert_called_once_with("state-key")
|
mock_retrieve_state.assert_called_once_with("state-key")
|
||||||
@ -541,9 +543,8 @@ class TestCallbackHandling:
|
|||||||
"test-verifier",
|
"test-verifier",
|
||||||
"https://redirect.example.com",
|
"https://redirect.example.com",
|
||||||
)
|
)
|
||||||
mock_service.save_oauth_data.assert_called_once_with(
|
# Note: handle_callback no longer saves tokens directly, it just returns them
|
||||||
"test-provider", "test-tenant", tokens.model_dump(), "tokens"
|
# The caller (e.g., controller) is responsible for saving via execute_auth_actions
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthOrchestration:
|
class TestAuthOrchestration:
|
||||||
@ -589,21 +590,28 @@ class TestAuthOrchestration:
|
|||||||
)
|
)
|
||||||
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
|
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
|
||||||
|
|
||||||
result = auth(mock_provider, mock_service)
|
result = auth(mock_provider)
|
||||||
|
|
||||||
assert result == {"authorization_url": "https://auth.example.com/authorize?..."}
|
# auth() now returns AuthResult
|
||||||
|
assert isinstance(result, AuthResult)
|
||||||
|
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
|
||||||
|
|
||||||
|
# Verify that the result contains the correct actions
|
||||||
|
assert len(result.actions) == 2
|
||||||
|
# Check for SAVE_CLIENT_INFO action
|
||||||
|
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
|
||||||
|
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
|
||||||
|
assert client_info_action.provider_id == "provider-id"
|
||||||
|
assert client_info_action.tenant_id == "tenant-id"
|
||||||
|
|
||||||
|
# Check for SAVE_CODE_VERIFIER action
|
||||||
|
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
|
||||||
|
assert verifier_action.data == {"code_verifier": "code-verifier"}
|
||||||
|
assert verifier_action.provider_id == "provider-id"
|
||||||
|
assert verifier_action.tenant_id == "tenant-id"
|
||||||
|
|
||||||
# Verify calls
|
# Verify calls
|
||||||
mock_register.assert_called_once()
|
mock_register.assert_called_once()
|
||||||
mock_service.save_oauth_data.assert_any_call(
|
|
||||||
"provider-id",
|
|
||||||
"tenant-id",
|
|
||||||
{"client_information": mock_register.return_value.model_dump()},
|
|
||||||
"client_info",
|
|
||||||
)
|
|
||||||
mock_service.save_oauth_data.assert_any_call(
|
|
||||||
"provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier"
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||||
@ -637,12 +645,18 @@ class TestAuthOrchestration:
|
|||||||
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
|
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
|
||||||
mock_exchange.return_value = tokens
|
mock_exchange.return_value = tokens
|
||||||
|
|
||||||
result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key")
|
result = auth(mock_provider, authorization_code="auth-code", state_param="state-key")
|
||||||
|
|
||||||
assert result == {"result": "success"}
|
# auth() now returns AuthResult, not a dict
|
||||||
|
assert isinstance(result, AuthResult)
|
||||||
|
assert result.response == {"result": "success"}
|
||||||
|
|
||||||
# Verify token save
|
# Verify that the result contains the correct action
|
||||||
mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens")
|
assert len(result.actions) == 1
|
||||||
|
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||||
|
assert result.actions[0].data == tokens.model_dump()
|
||||||
|
assert result.actions[0].provider_id == "provider-id"
|
||||||
|
assert result.actions[0].tenant_id == "tenant-id"
|
||||||
|
|
||||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
||||||
@ -658,7 +672,7 @@ class TestAuthOrchestration:
|
|||||||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
auth(mock_provider, mock_service, authorization_code="auth-code")
|
auth(mock_provider, authorization_code="auth-code")
|
||||||
|
|
||||||
assert "State parameter is required" in str(exc_info.value)
|
assert "State parameter is required" in str(exc_info.value)
|
||||||
|
|
||||||
@ -691,15 +705,21 @@ class TestAuthOrchestration:
|
|||||||
grant_types_supported=["authorization_code"],
|
grant_types_supported=["authorization_code"],
|
||||||
)
|
)
|
||||||
|
|
||||||
result = auth(mock_provider, mock_service)
|
result = auth(mock_provider)
|
||||||
|
|
||||||
assert result == {"result": "success"}
|
# auth() now returns AuthResult
|
||||||
|
assert isinstance(result, AuthResult)
|
||||||
|
assert result.response == {"result": "success"}
|
||||||
|
|
||||||
|
# Verify that the result contains the correct action
|
||||||
|
assert len(result.actions) == 1
|
||||||
|
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||||
|
assert result.actions[0].data == new_tokens.model_dump()
|
||||||
|
assert result.actions[0].provider_id == "provider-id"
|
||||||
|
assert result.actions[0].tenant_id == "tenant-id"
|
||||||
|
|
||||||
# Verify refresh was called
|
# Verify refresh was called
|
||||||
mock_refresh.assert_called_once()
|
mock_refresh.assert_called_once()
|
||||||
mock_service.save_oauth_data.assert_called_with(
|
|
||||||
"provider-id", "tenant-id", new_tokens.model_dump(), "tokens"
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||||
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
||||||
@ -715,6 +735,6 @@ class TestAuthOrchestration:
|
|||||||
mock_provider.retrieve_client_information.return_value = None
|
mock_provider.retrieve_client_information.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
auth(mock_provider, mock_service, authorization_code="auth-code")
|
auth(mock_provider, authorization_code="auth-code")
|
||||||
|
|
||||||
assert "Existing OAuth client information is required" in str(exc_info.value)
|
assert "Existing OAuth client information is required" in str(exc_info.value)
|
||||||
|
|||||||
@ -1,420 +0,0 @@
|
|||||||
"""Unit tests for MCP auth client with retry logic."""
|
|
||||||
|
|
||||||
from types import TracebackType
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.entities.mcp_provider import MCPProviderEntity
|
|
||||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
|
||||||
from core.mcp.error import MCPAuthError
|
|
||||||
from core.mcp.mcp_client import MCPClient
|
|
||||||
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientWithAuthRetry:
|
|
||||||
"""Test suite for MCPClientWithAuthRetry."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_provider_entity(self):
|
|
||||||
"""Create a mock provider entity."""
|
|
||||||
provider = Mock(spec=MCPProviderEntity)
|
|
||||||
provider.id = "test-provider-id"
|
|
||||||
provider.tenant_id = "test-tenant-id"
|
|
||||||
provider.retrieve_tokens.return_value = Mock(
|
|
||||||
access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_mcp_service(self):
|
|
||||||
"""Create a mock MCP service."""
|
|
||||||
service = Mock()
|
|
||||||
service.get_provider_entity.return_value = Mock(
|
|
||||||
retrieve_tokens=lambda: Mock(
|
|
||||||
access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return service
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def auth_callback(self):
|
|
||||||
"""Create a mock auth callback."""
|
|
||||||
return Mock()
|
|
||||||
|
|
||||||
def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test client initialization."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
headers={"Authorization": "Bearer test"},
|
|
||||||
timeout=30.0,
|
|
||||||
sse_read_timeout=60.0,
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
authorization_code="test-auth-code",
|
|
||||||
by_server_id=True,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert client.server_url == "http://test.example.com"
|
|
||||||
assert client.headers == {"Authorization": "Bearer test"}
|
|
||||||
assert client.timeout == 30.0
|
|
||||||
assert client.sse_read_timeout == 60.0
|
|
||||||
assert client.provider_entity == mock_provider_entity
|
|
||||||
assert client.auth_callback == auth_callback
|
|
||||||
assert client.authorization_code == "test-auth-code"
|
|
||||||
assert client.by_server_id is True
|
|
||||||
assert client.mcp_service == mock_mcp_service
|
|
||||||
assert client._has_retried is False
|
|
||||||
# In inheritance design, we don't have _client attribute
|
|
||||||
assert hasattr(client, "_session") # Inherited from MCPClient
|
|
||||||
|
|
||||||
def test_inheritance_structure(self):
|
|
||||||
"""Test that MCPClientWithAuthRetry properly inherits from MCPClient."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
headers={"Authorization": "Bearer test"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify inheritance
|
|
||||||
assert isinstance(client, MCPClient)
|
|
||||||
|
|
||||||
# Verify inherited attributes are accessible
|
|
||||||
assert hasattr(client, "server_url")
|
|
||||||
assert hasattr(client, "headers")
|
|
||||||
assert hasattr(client, "_session")
|
|
||||||
assert hasattr(client, "_exit_stack")
|
|
||||||
assert hasattr(client, "_initialized")
|
|
||||||
|
|
||||||
def test_handle_auth_error_no_retry_components(self):
|
|
||||||
"""Test auth error handling when retry components are missing."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
error = MCPAuthError("Auth failed")
|
|
||||||
|
|
||||||
with pytest.raises(MCPAuthError) as exc_info:
|
|
||||||
client._handle_auth_error(error)
|
|
||||||
|
|
||||||
assert exc_info.value == error
|
|
||||||
|
|
||||||
def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test auth error handling when already retried."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
client._has_retried = True
|
|
||||||
error = MCPAuthError("Auth failed")
|
|
||||||
|
|
||||||
with pytest.raises(MCPAuthError) as exc_info:
|
|
||||||
client._handle_auth_error(error)
|
|
||||||
|
|
||||||
assert exc_info.value == error
|
|
||||||
auth_callback.assert_not_called()
|
|
||||||
|
|
||||||
def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test successful auth refresh on error."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
authorization_code="test-code",
|
|
||||||
by_server_id=True,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure mocks
|
|
||||||
new_provider = Mock(spec=MCPProviderEntity)
|
|
||||||
new_provider.id = "test-provider-id"
|
|
||||||
new_provider.tenant_id = "test-tenant-id"
|
|
||||||
new_provider.retrieve_tokens.return_value = Mock(
|
|
||||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
|
||||||
|
|
||||||
error = MCPAuthError("Auth failed")
|
|
||||||
client._handle_auth_error(error)
|
|
||||||
|
|
||||||
# Verify auth flow
|
|
||||||
auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code")
|
|
||||||
mock_mcp_service.get_provider_entity.assert_called_once_with(
|
|
||||||
"test-provider-id", "test-tenant-id", by_server_id=True
|
|
||||||
)
|
|
||||||
assert client.headers["Authorization"] == "Bearer new-token"
|
|
||||||
assert client.authorization_code is None # Should be cleared after use
|
|
||||||
assert client._has_retried is True
|
|
||||||
|
|
||||||
def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test auth refresh failure."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_callback.side_effect = Exception("Auth callback failed")
|
|
||||||
|
|
||||||
error = MCPAuthError("Original auth failed")
|
|
||||||
with pytest.raises(MCPAuthError) as exc_info:
|
|
||||||
client._handle_auth_error(error)
|
|
||||||
|
|
||||||
assert "Authentication retry failed" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test auth refresh when no token is received."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure mock to return no token
|
|
||||||
new_provider = Mock(spec=MCPProviderEntity)
|
|
||||||
new_provider.retrieve_tokens.return_value = None
|
|
||||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
|
||||||
|
|
||||||
error = MCPAuthError("Auth failed")
|
|
||||||
with pytest.raises(MCPAuthError) as exc_info:
|
|
||||||
client._handle_auth_error(error)
|
|
||||||
|
|
||||||
assert "no token received" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_execute_with_retry_success(self):
|
|
||||||
"""Test successful execution without retry."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
mock_func = Mock(return_value="success")
|
|
||||||
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
|
|
||||||
|
|
||||||
assert result == "success"
|
|
||||||
mock_func.assert_called_once_with("arg1", kwarg1="value1")
|
|
||||||
assert client._has_retried is False
|
|
||||||
|
|
||||||
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test execution with auth error followed by successful retry."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure new provider with token
|
|
||||||
new_provider = Mock(spec=MCPProviderEntity)
|
|
||||||
new_provider.retrieve_tokens.return_value = Mock(
|
|
||||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
|
||||||
|
|
||||||
# Mock function that fails first, then succeeds
|
|
||||||
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
|
|
||||||
|
|
||||||
# Mock the exit stack and session cleanup
|
|
||||||
with (
|
|
||||||
patch.object(client, "_exit_stack") as mock_exit_stack,
|
|
||||||
patch.object(client, "_session") as mock_session,
|
|
||||||
patch.object(client, "_initialize") as mock_initialize,
|
|
||||||
):
|
|
||||||
client._initialized = True
|
|
||||||
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
|
|
||||||
|
|
||||||
assert result == "success"
|
|
||||||
assert mock_func.call_count == 2
|
|
||||||
mock_func.assert_called_with("arg1", kwarg1="value1")
|
|
||||||
auth_callback.assert_called_once()
|
|
||||||
mock_exit_stack.close.assert_called_once()
|
|
||||||
mock_initialize.assert_called_once()
|
|
||||||
assert client._has_retried is False # Reset after completion
|
|
||||||
|
|
||||||
def test_execute_with_retry_non_auth_error(self):
|
|
||||||
"""Test execution with non-auth error (no retry)."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
mock_func = Mock(side_effect=ValueError("Some other error"))
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
client._execute_with_retry(mock_func)
|
|
||||||
|
|
||||||
assert str(exc_info.value) == "Some other error"
|
|
||||||
mock_func.assert_called_once()
|
|
||||||
|
|
||||||
def test_context_manager_enter(self):
|
|
||||||
"""Test context manager enter."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_initialize") as mock_initialize:
|
|
||||||
result = client.__enter__()
|
|
||||||
|
|
||||||
assert result == client
|
|
||||||
assert client._initialized is True
|
|
||||||
mock_initialize.assert_called_once()
|
|
||||||
|
|
||||||
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test context manager enter with auth error and retry."""
|
|
||||||
# Configure new provider with token
|
|
||||||
new_provider = Mock(spec=MCPProviderEntity)
|
|
||||||
new_provider.retrieve_tokens.return_value = Mock(
|
|
||||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
|
||||||
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock parent class __enter__ to raise auth error first, then succeed
|
|
||||||
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
|
|
||||||
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
|
|
||||||
|
|
||||||
result = client.__enter__()
|
|
||||||
|
|
||||||
assert result == client
|
|
||||||
assert mock_parent_enter.call_count == 2
|
|
||||||
auth_callback.assert_called_once()
|
|
||||||
|
|
||||||
def test_context_manager_exit(self):
|
|
||||||
"""Test context manager exit."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "cleanup") as mock_cleanup:
|
|
||||||
exc_type: type[BaseException] | None = None
|
|
||||||
exc_val: BaseException | None = None
|
|
||||||
exc_tb: TracebackType | None = None
|
|
||||||
client.__exit__(exc_type, exc_val, exc_tb)
|
|
||||||
|
|
||||||
mock_cleanup.assert_called_once()
|
|
||||||
|
|
||||||
def test_list_tools_not_initialized(self):
|
|
||||||
"""Test list_tools when client not initialized."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
client.list_tools()
|
|
||||||
|
|
||||||
assert "Session not initialized" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_list_tools_success(self):
|
|
||||||
"""Test successful list_tools call."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
expected_tools = [
|
|
||||||
Tool(
|
|
||||||
name="test-tool",
|
|
||||||
description="A test tool",
|
|
||||||
inputSchema={"type": "object", "properties": {}},
|
|
||||||
annotations=ToolAnnotations(title="Test Tool"),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Mock the parent class list_tools method
|
|
||||||
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
|
|
||||||
result = client.list_tools()
|
|
||||||
assert result == expected_tools
|
|
||||||
|
|
||||||
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test list_tools with auth retry."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure new provider with token
|
|
||||||
new_provider = Mock(spec=MCPProviderEntity)
|
|
||||||
new_provider.retrieve_tokens.return_value = Mock(
|
|
||||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
|
||||||
|
|
||||||
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
|
|
||||||
|
|
||||||
# Mock parent class list_tools to raise auth error first, then succeed
|
|
||||||
with patch.object(MCPClient, "list_tools") as mock_list_tools:
|
|
||||||
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
|
|
||||||
|
|
||||||
result = client.list_tools()
|
|
||||||
|
|
||||||
assert result == expected_tools
|
|
||||||
assert mock_list_tools.call_count == 2
|
|
||||||
auth_callback.assert_called_once()
|
|
||||||
|
|
||||||
def test_invoke_tool_not_initialized(self):
|
|
||||||
"""Test invoke_tool when client not initialized."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
client.invoke_tool("test-tool", {"arg": "value"})
|
|
||||||
|
|
||||||
assert "Session not initialized" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_invoke_tool_success(self):
|
|
||||||
"""Test successful invoke_tool call."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
expected_result = CallToolResult(
|
|
||||||
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the parent class invoke_tool method
|
|
||||||
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
|
|
||||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
|
||||||
|
|
||||||
assert result == expected_result
|
|
||||||
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
|
|
||||||
|
|
||||||
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
|
|
||||||
"""Test invoke_tool with auth retry."""
|
|
||||||
client = MCPClientWithAuthRetry(
|
|
||||||
server_url="http://test.example.com",
|
|
||||||
provider_entity=mock_provider_entity,
|
|
||||||
auth_callback=auth_callback,
|
|
||||||
mcp_service=mock_mcp_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure new provider with token
|
|
||||||
new_provider = Mock(spec=MCPProviderEntity)
|
|
||||||
new_provider.retrieve_tokens.return_value = Mock(
|
|
||||||
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
|
|
||||||
)
|
|
||||||
mock_mcp_service.get_provider_entity.return_value = new_provider
|
|
||||||
|
|
||||||
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
|
|
||||||
|
|
||||||
# Mock parent class invoke_tool to raise auth error first, then succeed
|
|
||||||
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
|
|
||||||
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
|
|
||||||
|
|
||||||
result = client.invoke_tool("test-tool", {"arg": "value"})
|
|
||||||
|
|
||||||
assert result == expected_result
|
|
||||||
assert mock_invoke_tool.call_count == 2
|
|
||||||
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
|
|
||||||
auth_callback.assert_called_once()
|
|
||||||
|
|
||||||
def test_cleanup(self):
|
|
||||||
"""Test cleanup method."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
# Mock the parent class cleanup method
|
|
||||||
with patch.object(MCPClient, "cleanup") as mock_cleanup:
|
|
||||||
client.cleanup()
|
|
||||||
mock_cleanup.assert_called_once()
|
|
||||||
|
|
||||||
def test_cleanup_no_client(self):
|
|
||||||
"""Test cleanup when no client exists."""
|
|
||||||
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
|
|
||||||
|
|
||||||
# Should not raise
|
|
||||||
client.cleanup()
|
|
||||||
|
|
||||||
# Since MCPClientWithAuthRetry inherits from MCPClient,
|
|
||||||
# it doesn't have a _client attribute. The test should just
|
|
||||||
# verify that cleanup can be called without error.
|
|
||||||
assert not hasattr(client, "_client")
|
|
||||||
Loading…
Reference in New Issue
Block a user