mirror of
https://github.com/langgenius/dify.git
synced 2026-04-13 22:57:26 +08:00
refactor(mcp): clean the oauth code
This commit is contained in:
parent
aed9955105
commit
f137af4ec5
@ -956,7 +956,7 @@ class ToolMCPAuthApi(Resource):
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider_by_id(provider_id, tenant_id)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
@ -984,6 +984,7 @@ class ToolMCPAuthApi(Resource):
|
||||
else None, # Only use auth retry if no custom headers
|
||||
auth_callback=auth if not provider_entity.headers else None,
|
||||
authorization_code=args.get("authorization_code"),
|
||||
mcp_service=service,
|
||||
):
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
@ -1007,7 +1008,7 @@ class ToolMCPDetailApi(Resource):
|
||||
user = current_user
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider_by_id(provider_id, user.current_tenant_id)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=user.current_tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
@ -1049,7 +1050,13 @@ class ToolMCPCallbackApi(Resource):
|
||||
args = parser.parse_args()
|
||||
state_key = args["state"]
|
||||
authorization_code = args["code"]
|
||||
handle_callback(state_key, authorization_code)
|
||||
|
||||
# Create service instance for handle_callback
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
handle_callback(state_key, authorization_code, mcp_service)
|
||||
session.commit()
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
|
||||
@ -4,12 +4,11 @@ import json
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import (
|
||||
@ -20,9 +19,10 @@ from core.mcp.types import (
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.tools.mcp_oauth_service import MCPOAuthService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
@ -84,7 +84,7 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
@ -99,10 +99,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
||||
)
|
||||
|
||||
# Save tokens using the service layer
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
oauth_service.save_tokens(full_state_data.provider_id, full_state_data.tenant_id, tokens)
|
||||
session.commit()
|
||||
mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens")
|
||||
|
||||
return full_state_data
|
||||
|
||||
@ -304,6 +301,7 @@ def register_client(
|
||||
|
||||
def auth(
|
||||
provider: MCPProviderEntity,
|
||||
mcp_service: "MCPToolManageService",
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
) -> dict[str, str]:
|
||||
@ -325,10 +323,9 @@ def auth(
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
|
||||
# Save client information using service layer
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
oauth_service.save_client_information(provider_id, tenant_id, full_information)
|
||||
session.commit()
|
||||
mcp_service.save_oauth_data(
|
||||
provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info"
|
||||
)
|
||||
|
||||
client_information = full_information
|
||||
|
||||
@ -360,10 +357,7 @@ def auth(
|
||||
)
|
||||
|
||||
# Save tokens using service layer
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
oauth_service.save_tokens(provider_id, tenant_id, tokens)
|
||||
session.commit()
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -377,10 +371,7 @@ def auth(
|
||||
)
|
||||
|
||||
# Save new tokens using service layer
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
oauth_service.save_tokens(provider_id, tenant_id, new_tokens)
|
||||
session.commit()
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
|
||||
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
@ -397,9 +388,6 @@ def auth(
|
||||
)
|
||||
|
||||
# Save code verifier using service layer
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
oauth_service.save_code_verifier(provider_id, tenant_id, code_verifier)
|
||||
session.commit()
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier")
|
||||
|
||||
return {"authorization_url": authorization_url}
|
||||
|
||||
@ -8,15 +8,15 @@ authentication failures and retries operations after refreshing tokens.
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -36,9 +36,11 @@ class MCPClientWithAuthRetry:
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]] | None = None,
|
||||
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]]
|
||||
| None = None,
|
||||
authorization_code: Optional[str] = None,
|
||||
by_server_id: bool = False,
|
||||
mcp_service: Optional["MCPToolManageService"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
@ -62,6 +64,7 @@ class MCPClientWithAuthRetry:
|
||||
self._has_retried = False
|
||||
self._client: MCPClient | None = None
|
||||
self.by_server_id = by_server_id
|
||||
self.mcp_service = mcp_service
|
||||
|
||||
def _create_client(self) -> MCPClient:
|
||||
"""Create a new MCPClient instance with current headers."""
|
||||
@ -82,11 +85,8 @@ class MCPClientWithAuthRetry:
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
from services.tools.mcp_oauth_service import MCPOAuthService
|
||||
|
||||
if not self.provider_entity or not self.auth_callback:
|
||||
if not self.provider_entity or not self.auth_callback or not self.mcp_service:
|
||||
raise error
|
||||
|
||||
if self._has_retried:
|
||||
raise error
|
||||
|
||||
@ -94,14 +94,12 @@ class MCPClientWithAuthRetry:
|
||||
|
||||
try:
|
||||
# Perform authentication
|
||||
self.auth_callback(self.provider_entity, self.authorization_code)
|
||||
self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code)
|
||||
|
||||
# Retrieve new tokens
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
self.provider_entity = oauth_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
self.provider_entity = self.mcp_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
token = self.provider_entity.retrieve_tokens()
|
||||
if not token:
|
||||
raise MCPAuthError("Authentication failed - no token received")
|
||||
|
||||
@ -119,23 +119,39 @@ class MCPTool(Tool):
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
# Get provider entity to access tokens
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_oauth_service import MCPOAuthService
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# Get MCP service from invoke parameters or create new one
|
||||
provider_entity = None
|
||||
mcp_service = None
|
||||
|
||||
# Check if mcp_service is passed in tool_parameters
|
||||
if "_mcp_service" in tool_parameters:
|
||||
mcp_service = tool_parameters.pop("_mcp_service")
|
||||
else:
|
||||
# Fallback to creating service with database session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
oauth_service = MCPOAuthService(session=session)
|
||||
provider_entity = oauth_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
if mcp_service:
|
||||
try:
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Try to get existing token and add to headers
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
except Exception:
|
||||
# If provider retrieval or token fails, continue without auth
|
||||
pass
|
||||
except Exception:
|
||||
# If provider retrieval or token fails, continue without auth
|
||||
pass
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
@ -145,8 +161,9 @@ class MCPTool(Tool):
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
by_server_id=True,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
|
||||
@ -775,7 +775,7 @@ class ToolManager:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
try:
|
||||
provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
@ -918,7 +918,9 @@ class ToolManager:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
try:
|
||||
mcp_provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||
mcp_provider = mcp_service.get_provider_entity(
|
||||
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
||||
)
|
||||
return mcp_provider.provider_icon
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
"""
|
||||
MCP OAuth Service - handles OAuth-related database operations
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import OAuthClientInformationFull, OAuthTokens
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
|
||||
class MCPOAuthService:
|
||||
"""Service for handling MCP OAuth operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
self._mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||
"""Get provider entity by ID"""
|
||||
if by_server_id:
|
||||
db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||
else:
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
return db_provider.to_entity()
|
||||
|
||||
def save_client_information(
|
||||
self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull
|
||||
) -> None:
|
||||
"""Save OAuth client information"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
self._mcp_service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials={"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None:
|
||||
"""Save OAuth tokens"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
token_dict = tokens.model_dump()
|
||||
self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed)
|
||||
|
||||
def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None:
|
||||
"""Save PKCE code verifier"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
self._mcp_service.update_provider_credentials(
|
||||
provider=db_provider, credentials={"code_verifier": code_verifier}
|
||||
)
|
||||
|
||||
def clear_credentials(self, provider_id: str, tenant_id: str) -> None:
|
||||
"""Clear provider credentials"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
self._mcp_service.clear_provider_credentials(provider=db_provider)
|
||||
@ -14,7 +14,6 @@ from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.types import OAuthTokens
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from models.tools import MCPToolProvider
|
||||
@ -26,85 +25,51 @@ UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
# ========== Provider CRUD Operations ==========
|
||||
|
||||
def get_provider(
|
||||
self, *, provider_id: Optional[str] = None, server_identifier: Optional[str] = None, tenant_id: str
|
||||
) -> MCPToolProvider:
|
||||
"""
|
||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||
Get MCP provider by ID or server identifier.
|
||||
|
||||
Args:
|
||||
headers: Dictionary of headers to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
provider_id: Provider ID (UUID)
|
||||
server_identifier: Server identifier
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Dictionary with all headers encrypted
|
||||
MCPToolProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
if server_identifier:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||
)
|
||||
else:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
|
||||
)
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter_instance.encrypt(headers)
|
||||
|
||||
def _retrieve_remote_mcp_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
provider_entity: MCPProviderEntity,
|
||||
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]],
|
||||
):
|
||||
"""Retrieve tools from remote MCP server"""
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return tools
|
||||
|
||||
def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]:
|
||||
"""Process headers and add OAuth token if available"""
|
||||
headers = headers.copy() if headers else {}
|
||||
if tokens:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
return headers
|
||||
|
||||
def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
"""Get MCP provider by ID"""
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||
provider = self._session.scalar(stmt)
|
||||
if not provider:
|
||||
raise ValueError("MCP tool not found")
|
||||
return provider
|
||||
|
||||
def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||
"""Get MCP provider by server identifier"""
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||
)
|
||||
provider = self._session.scalar(stmt)
|
||||
if not provider:
|
||||
raise ValueError("MCP tool not found")
|
||||
return provider
|
||||
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||
"""Get provider entity by ID or server identifier."""
|
||||
if by_server_id:
|
||||
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||
else:
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return db_provider.to_entity()
|
||||
|
||||
def create_provider(
|
||||
self,
|
||||
@ -121,36 +86,15 @@ class MCPToolManageService:
|
||||
sse_read_timeout: float,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider"""
|
||||
"""Create a new MCP provider."""
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
# Check for existing provider
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
|
||||
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if existing_provider.server_url_hash == server_url_hash:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
|
||||
# Encrypt server URL
|
||||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
|
||||
# Encrypt headers
|
||||
encrypted_headers = None
|
||||
if headers:
|
||||
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None
|
||||
|
||||
# Create provider
|
||||
mcp_tool = MCPToolProvider(
|
||||
@ -161,7 +105,7 @@ class MCPToolManageService:
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools="[]",
|
||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||
icon=self._prepare_icon(icon, icon_type, icon_background),
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
@ -173,59 +117,6 @@ class MCPToolManageService:
|
||||
|
||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
|
||||
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
"""List all MCP providers for a tenant"""
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
|
||||
mcp_providers = self._session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
|
||||
for provider in mcp_providers
|
||||
]
|
||||
|
||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
"""List tools from remote MCP server"""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
# Load provider and convert to entity
|
||||
db_provider = self.get_provider_by_id(provider_id, tenant_id)
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
# Handle authentication headers if authed
|
||||
if not provider_entity.authed:
|
||||
raise ValueError("Please auth the tool first")
|
||||
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
headers = self._process_headers(provider_entity.headers, tokens)
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database record with new tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.commit()
|
||||
|
||||
# Create API response using entity
|
||||
user = db_provider.load_user()
|
||||
response = provider_entity.to_api_response(
|
||||
user_name=user.name if user else None,
|
||||
)
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||
"""Delete an MCP provider"""
|
||||
mcp_tool = self.get_provider_by_id(provider_id, tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
self._session.commit()
|
||||
|
||||
def update_provider(
|
||||
self,
|
||||
*,
|
||||
@ -241,8 +132,8 @@ class MCPToolManageService:
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Update an MCP provider"""
|
||||
mcp_provider = self.get_provider_by_id(provider_id, tenant_id)
|
||||
"""Update an MCP provider."""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
reconnect_result = None
|
||||
encrypted_server_url = None
|
||||
@ -263,9 +154,7 @@ class MCPToolManageService:
|
||||
# Update basic fields
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.name = name
|
||||
mcp_provider.icon = (
|
||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||
)
|
||||
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
|
||||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
# Update server URL if changed
|
||||
@ -284,35 +173,86 @@ class MCPToolManageService:
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
if headers is not None:
|
||||
# Encrypt headers
|
||||
if headers:
|
||||
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
mcp_provider.encrypted_headers = None
|
||||
mcp_provider.encrypted_headers = (
|
||||
self._prepare_encrypted_headers(headers, tenant_id) if headers else None
|
||||
)
|
||||
|
||||
self._session.commit()
|
||||
|
||||
except IntegrityError as e:
|
||||
self._session.rollback()
|
||||
error_msg = str(e.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
except Exception:
|
||||
self._session.rollback()
|
||||
raise
|
||||
|
||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||
"""Delete an MCP provider."""
|
||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
self._session.commit()
|
||||
|
||||
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
"""List all MCP providers for a tenant."""
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
mcp_providers = self._session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
|
||||
for provider in mcp_providers
|
||||
]
|
||||
|
||||
# ========== Tool Operations ==========
|
||||
|
||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
"""List tools from remote MCP server."""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
# Load provider and convert to entity
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
# Verify authentication
|
||||
if not provider_entity.authed:
|
||||
raise ValueError("Please auth the tool first")
|
||||
|
||||
# Prepare headers with auth token
|
||||
headers = self._prepare_auth_headers(provider_entity)
|
||||
|
||||
# Retrieve tools from remote server
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(
|
||||
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
|
||||
)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.commit()
|
||||
|
||||
# Build API response
|
||||
return self._build_tool_provider_response(db_provider, provider_entity, tools)
|
||||
|
||||
# ========== OAuth and Credentials Operations ==========
|
||||
|
||||
def update_provider_credentials(
|
||||
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None
|
||||
) -> None:
|
||||
"""Update provider credentials"""
|
||||
"""
|
||||
Update provider credentials with encryption.
|
||||
|
||||
Args:
|
||||
provider: Provider instance
|
||||
credentials: Credentials to save
|
||||
authed: Whether provider is authenticated (None means keep current state)
|
||||
"""
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
|
||||
# Encrypt new credentials
|
||||
provider_controller = MCPToolProviderController.from_db(provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=provider.tenant_id,
|
||||
@ -320,24 +260,130 @@ class MCPToolManageService:
|
||||
provider_config_cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
|
||||
# Update provider
|
||||
provider.updated_at = datetime.now()
|
||||
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
|
||||
provider.authed = authed
|
||||
if not authed:
|
||||
provider.tools = "[]"
|
||||
|
||||
if authed is not None:
|
||||
provider.authed = authed
|
||||
if not authed:
|
||||
provider.tools = "[]"
|
||||
|
||||
self._session.commit()
|
||||
|
||||
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
|
||||
"""
|
||||
Save OAuth-related data (tokens, client info, code verifier).
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
data: Data to save (tokens, client info, or code verifier)
|
||||
data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed')
|
||||
"""
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
credentials = {}
|
||||
authed = None
|
||||
|
||||
if data_type == "tokens" or (data_type == "mixed" and "access_token" in data):
|
||||
# OAuth tokens
|
||||
credentials = data
|
||||
authed = True
|
||||
elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data):
|
||||
# OAuth client information
|
||||
credentials = data
|
||||
elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data):
|
||||
# PKCE code verifier
|
||||
credentials = data
|
||||
else:
|
||||
credentials = data
|
||||
|
||||
self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed)
|
||||
|
||||
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
|
||||
"""Clear provider credentials"""
|
||||
"""Clear all credentials for a provider."""
|
||||
provider.tools = "[]"
|
||||
provider.encrypted_credentials = "{}"
|
||||
provider.updated_at = datetime.now()
|
||||
provider.authed = False
|
||||
self._session.commit()
|
||||
|
||||
# ========== Private Helper Methods ==========
|
||||
|
||||
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
|
||||
"""Check if provider with same attributes already exists."""
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if existing_provider.server_url_hash == server_url_hash:
|
||||
raise ValueError("MCP tool with this server URL already exists")
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
|
||||
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
|
||||
"""Prepare icon data for storage."""
|
||||
if icon_type == "emoji":
|
||||
return json.dumps({"content": icon, "background": icon_background})
|
||||
return icon
|
||||
|
||||
def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||
"""Encrypt headers and prepare for storage."""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
encrypted_headers_dict = encrypter_instance.encrypt(headers)
|
||||
return json.dumps(encrypted_headers_dict)
|
||||
|
||||
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
|
||||
"""Prepare headers with OAuth token if available."""
|
||||
headers = provider_entity.headers.copy() if provider_entity.headers else {}
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
return headers
|
||||
|
||||
def _retrieve_remote_mcp_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
provider_entity: MCPProviderEntity,
|
||||
auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]],
|
||||
):
|
||||
"""Retrieve tools from remote MCP server."""
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
mcp_service=self,
|
||||
) as mcp_client:
|
||||
return mcp_client.list_tools()
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
|
||||
"""Attempt to reconnect to MCP provider with new server URL"""
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
provider_entity = provider.to_entity()
|
||||
@ -352,7 +398,8 @@ class MCPToolManageService:
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth,
|
||||
auth_callback=lambda p, s, c: auth(p, self, c),
|
||||
mcp_service=self,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
@ -364,3 +411,28 @@ class MCPToolManageService:
|
||||
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Build API response for tool provider."""
|
||||
user = db_provider.load_user()
|
||||
response = provider_entity.to_api_response(
|
||||
user_name=user.name if user else None,
|
||||
)
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
def _handle_integrity_error(
|
||||
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
|
||||
) -> None:
|
||||
"""Handle database integrity errors with user-friendly messages."""
|
||||
error_msg = str(error.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
|
||||
Loading…
Reference in New Issue
Block a user