refactor(mcp): clean the oauth code

This commit is contained in:
Novice 2025-09-16 14:16:38 +08:00
parent aed9955105
commit f137af4ec5
7 changed files with 309 additions and 278 deletions

View File

@ -956,7 +956,7 @@ class ToolMCPAuthApi(Resource):
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
db_provider = service.get_provider_by_id(provider_id, tenant_id)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
@ -984,6 +984,7 @@ class ToolMCPAuthApi(Resource):
else None, # Only use auth retry if no custom headers
auth_callback=auth if not provider_entity.headers else None,
authorization_code=args.get("authorization_code"),
mcp_service=service,
):
service.update_provider_credentials(
provider=db_provider,
@ -1007,7 +1008,7 @@ class ToolMCPDetailApi(Resource):
user = current_user
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
provider = service.get_provider_by_id(provider_id, user.current_tenant_id)
provider = service.get_provider(provider_id=provider_id, tenant_id=user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1049,7 +1050,13 @@ class ToolMCPCallbackApi(Resource):
args = parser.parse_args()
state_key = args["state"]
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
# Create service instance for handle_callback
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -4,12 +4,11 @@ import json
import os
import secrets
import urllib.parse
from typing import Optional
from typing import TYPE_CHECKING, Optional
from urllib.parse import urljoin, urlparse
import httpx
from pydantic import BaseModel, ValidationError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import (
@ -20,9 +19,10 @@ from core.mcp.types import (
OAuthMetadata,
OAuthTokens,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from services.tools.mcp_oauth_service import MCPOAuthService
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
@ -84,7 +84,7 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@ -99,10 +99,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
)
# Save tokens using the service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_tokens(full_state_data.provider_id, full_state_data.tenant_id, tokens)
session.commit()
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
return full_state_data
@ -304,6 +301,7 @@ def register_client(
def auth(
provider: MCPProviderEntity,
mcp_service: "MCPToolManageService",
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
) -> dict[str, str]:
@ -325,10 +323,9 @@ def auth(
raise ValueError(f"Could not register OAuth client: {e}")
# Save client information using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_client_information(provider_id, tenant_id, full_information)
session.commit()
mcp_service.save_oauth_data(
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
)
client_information = full_information
@ -360,10 +357,7 @@ def auth(
)
# Save tokens using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_tokens(provider_id, tenant_id, tokens)
session.commit()
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
return {"result": "success"}
@ -377,10 +371,7 @@ def auth(
)
# Save new tokens using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_tokens(provider_id, tenant_id, new_tokens)
session.commit()
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
return {"result": "success"}
except Exception as e:
@ -397,9 +388,6 @@ def auth(
)
# Save code verifier using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_code_verifier(provider_id, tenant_id, code_verifier)
session.commit()
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
return {"authorization_url": authorization_url}

View File

@ -8,15 +8,15 @@ authentication failures and retries operations after refreshing tokens.
import logging
from collections.abc import Callable
from types import TracebackType
from typing import Any, Optional
from sqlalchemy.orm import Session
from typing import TYPE_CHECKING, Any, Optional
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db
if TYPE_CHECKING:
from services.tools.mcp_tools_manage_service import MCPToolManageService
logger = logging.getLogger(__name__)
@ -36,9 +36,11 @@ class MCPClientWithAuthRetry:
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]] | None = None,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
| None = None,
authorization_code: Optional[str] = None,
by_server_id: bool = False,
mcp_service: Optional["MCPToolManageService"] = None,
):
"""
Initialize the MCP client with auth retry capability.
@ -62,6 +64,7 @@ class MCPClientWithAuthRetry:
self._has_retried = False
self._client: MCPClient | None = None
self.by_server_id = by_server_id
self.mcp_service = mcp_service
def _create_client(self) -> MCPClient:
"""Create a new MCPClient instance with current headers."""
@ -82,11 +85,8 @@ class MCPClientWithAuthRetry:
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
from services.tools.mcp_oauth_service import MCPOAuthService
if not self.provider_entity or not self.auth_callback:
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
raise error
if self._has_retried:
raise error
@ -94,14 +94,12 @@ class MCPClientWithAuthRetry:
try:
# Perform authentication
self.auth_callback(self.provider_entity, self.authorization_code)
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
# Retrieve new tokens
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
self.provider_entity = oauth_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
self.provider_entity = self.mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")

View File

@ -119,23 +119,39 @@ class MCPTool(Tool):
tool_parameters = self._handle_none_parameter(tool_parameters)
# Get provider entity to access tokens
from sqlalchemy.orm import Session
from typing import TYPE_CHECKING
from extensions.ext_database import db
from services.tools.mcp_oauth_service import MCPOAuthService
if TYPE_CHECKING:
pass
# Get MCP service from invoke parameters or create new one
provider_entity = None
mcp_service = None
# Check if mcp_service is passed in tool_parameters
if "_mcp_service" in tool_parameters:
mcp_service = tool_parameters.pop("_mcp_service")
else:
# Fallback to creating service with database session
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
try:
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
provider_entity = oauth_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
mcp_service = MCPToolManageService(session=session)
if mcp_service:
try:
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Try to get existing token and add to headers
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
except Exception:
# If provider retrieval or token fails, continue without auth
pass
except Exception:
# If provider retrieval or token fails, continue without auth
pass
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
@ -145,8 +161,9 @@ class MCPTool(Tool):
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth,
auth_callback=auth if mcp_service else None,
by_server_id=True,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:

View File

@ -775,7 +775,7 @@ class ToolManager:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
@ -918,7 +918,9 @@ class ToolManager:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
mcp_provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")

View File

@ -1,53 +0,0 @@
"""
MCP OAuth Service - handles OAuth-related database operations
"""
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import OAuthClientInformationFull, OAuthTokens
from services.tools.mcp_tools_manage_service import MCPToolManageService
class MCPOAuthService:
"""Service for handling MCP OAuth operations"""
def __init__(self, session: Session):
self._session = session
self._mcp_service = MCPToolManageService(session=session)
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID"""
if by_server_id:
db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
else:
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
return db_provider.to_entity()
def save_client_information(
self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull
) -> None:
"""Save OAuth client information"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.update_provider_credentials(
provider=db_provider,
credentials={"client_information": client_information.model_dump()},
)
def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None:
"""Save OAuth tokens"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
token_dict = tokens.model_dump()
self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed)
def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None:
"""Save PKCE code verifier"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.update_provider_credentials(
provider=db_provider, credentials={"code_verifier": code_verifier}
)
def clear_credentials(self, provider_id: str, tenant_id: str) -> None:
"""Clear provider credentials"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.clear_provider_credentials(provider=db_provider)

View File

@ -14,7 +14,6 @@ from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.types import OAuthTokens
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.utils.encryption import ProviderConfigEncrypter
from models.tools import MCPToolProvider
@ -26,85 +25,51 @@ UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
class MCPToolManageService:
"""
Service class for managing mcp tools.
"""
"""Service class for managing MCP tools and providers."""
def __init__(self, session: Session):
self._session = session
def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]:
# ========== Provider CRUD Operations ==========
def get_provider(
self, *, provider_id: Optional[str] = None, server_identifier: Optional[str] = None, tenant_id: str
) -> MCPToolProvider:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
Get MCP provider by ID or server identifier.
Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption
provider_id: Provider ID (UUID)
server_identifier: Server identifier
tenant_id: Tenant ID
Returns:
Dictionary with all headers encrypted
MCPToolProvider instance
Raises:
ValueError: If provider not found
"""
if not headers:
return {}
if server_identifier:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
else:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
)
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
return encrypter_instance.encrypt(headers)
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]],
):
"""Retrieve tools from remote MCP server"""
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth_callback,
) as mcp_client:
tools = mcp_client.list_tools()
return tools
def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]:
"""Process headers and add OAuth token if available"""
headers = headers.copy() if headers else {}
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider:
"""Get MCP provider by ID"""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return provider
def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider:
"""Get MCP provider by server identifier"""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return provider
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID or server identifier."""
if by_server_id:
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
else:
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return db_provider.to_entity()
def create_provider(
self,
@ -121,36 +86,15 @@ class MCPToolManageService:
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider"""
"""Create a new MCP provider."""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
# Check for existing provider
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError(f"MCP tool {server_url} already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
# Encrypt server URL
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None
# Create provider
mcp_tool = MCPToolProvider(
@ -161,7 +105,7 @@ class MCPToolManageService:
user_id=user_id,
authed=False,
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
@ -173,59 +117,6 @@ class MCPToolManageService:
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant"""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
return [
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
for provider in mcp_providers
]
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server"""
from core.mcp.auth.auth_flow import auth
# Load provider and convert to entity
db_provider = self.get_provider_by_id(provider_id, tenant_id)
provider_entity = db_provider.to_entity()
# Handle authentication headers if authed
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
tokens = provider_entity.retrieve_tokens()
headers = self._process_headers(provider_entity.headers, tokens)
server_url = provider_entity.decrypt_server_url()
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database record with new tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.commit()
# Create API response using entity
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
return ToolProviderApiEntity(**response)
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider"""
mcp_tool = self.get_provider_by_id(provider_id, tenant_id)
self._session.delete(mcp_tool)
self._session.commit()
def update_provider(
self,
*,
@ -241,8 +132,8 @@ class MCPToolManageService:
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
) -> None:
"""Update an MCP provider"""
mcp_provider = self.get_provider_by_id(provider_id, tenant_id)
"""Update an MCP provider."""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
reconnect_result = None
encrypted_server_url = None
@ -263,9 +154,7 @@ class MCPToolManageService:
# Update basic fields
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
mcp_provider.server_identifier = server_identifier
# Update server URL if changed
@ -284,35 +173,86 @@ class MCPToolManageService:
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
# Encrypt headers
if headers:
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
mcp_provider.encrypted_headers = None
mcp_provider.encrypted_headers = (
self._prepare_encrypted_headers(headers, tenant_id) if headers else None
)
self._session.commit()
except IntegrityError as e:
self._session.rollback()
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
self._handle_integrity_error(e, name, server_url, server_identifier)
except Exception:
self._session.rollback()
raise
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider."""
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
self._session.commit()
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant."""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
return [
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
for provider in mcp_providers
]
# ========== Tool Operations ==========
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server."""
from core.mcp.auth.auth_flow import auth
# Load provider and convert to entity
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = db_provider.to_entity()
# Verify authentication
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
# Prepare headers with auth token
headers = self._prepare_auth_headers(provider_entity)
# Retrieve tools from remote server
server_url = provider_entity.decrypt_server_url()
try:
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.commit()
# Build API response
return self._build_tool_provider_response(db_provider, provider_entity, tools)
# ========== OAuth and Credentials Operations ==========
def update_provider_credentials(
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None
) -> None:
"""Update provider credentials"""
"""
Update provider credentials with encryption.
Args:
provider: Provider instance
credentials: Credentials to save
authed: Whether provider is authenticated (None means keep current state)
"""
from core.tools.mcp_tool.provider import MCPToolProviderController
# Encrypt new credentials
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=provider.tenant_id,
@ -320,24 +260,130 @@ class MCPToolManageService:
provider_config_cache=NoOpProviderCredentialCache(),
)
encrypted_credentials = tool_configuration.encrypt(credentials)
# Update provider
provider.updated_at = datetime.now()
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
provider.authed = authed
if not authed:
provider.tools = "[]"
if authed is not None:
provider.authed = authed
if not authed:
provider.tools = "[]"
self._session.commit()
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
"""
Save OAuth-related data (tokens, client info, code verifier).
Args:
provider_id: Provider ID
tenant_id: Tenant ID
data: Data to save (tokens, client info, or code verifier)
data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed')
"""
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
credentials = {}
authed = None
if data_type == "tokens" or (data_type == "mixed" and "access_token" in data):
# OAuth tokens
credentials = data
authed = True
elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data):
# OAuth client information
credentials = data
elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data):
# PKCE code verifier
credentials = data
else:
credentials = data
self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed)
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear provider credentials"""
"""Clear all credentials for a provider."""
provider.tools = "[]"
provider.encrypted_credentials = "{}"
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
# ========== Private Helper Methods ==========
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
"""Check if provider with same attributes already exists."""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError("MCP tool with this server URL already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
"""Prepare icon data for storage."""
if icon_type == "emoji":
return json.dumps({"content": icon, "background": icon_background})
return icon
def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
from core.entities.provider_entities import BasicProviderConfig
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
encrypted_headers_dict = encrypter_instance.encrypt(headers)
return json.dumps(encrypted_headers_dict)
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
headers = provider_entity.headers.copy() if provider_entity.headers else {}
tokens = provider_entity.retrieve_tokens()
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]],
):
"""Retrieve tools from remote MCP server."""
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth_callback,
mcp_service=self,
) as mcp_client:
return mcp_client.list_tools()
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
"""Attempt to reconnect to MCP provider with new server URL"""
"""Attempt to reconnect to MCP provider with new server URL."""
from core.mcp.auth.auth_flow import auth
provider_entity = provider.to_entity()
@ -352,7 +398,8 @@ class MCPToolManageService:
timeout=timeout,
sse_read_timeout=sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth,
auth_callback=lambda p, s, c: auth(p, self, c),
mcp_service=self,
) as mcp_client:
tools = mcp_client.list_tools()
return {
@ -364,3 +411,28 @@ class MCPToolManageService:
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
"""Build API response for tool provider."""
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
return ToolProviderApiEntity(**response)
def _handle_integrity_error(
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
) -> None:
"""Handle database integrity errors with user-friendly messages."""
error_msg = str(error.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise