refactor(mcp): clean the client service code

This commit is contained in:
Novice 2025-09-16 10:54:31 +08:00
parent f16151ea29
commit aed9955105
13 changed files with 858 additions and 530 deletions

View File

@ -7,6 +7,7 @@ from flask_restx import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -17,13 +18,13 @@ from controllers.console.wraps import (
setup_required,
)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
@ -870,8 +871,9 @@ class ToolProviderMCPApi(Resource):
user = current_user
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=user.current_tenant_id,
server_url=args["server_url"],
name=args["name"],
@ -884,7 +886,8 @@ class ToolProviderMCPApi(Resource):
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
)
)
session.commit()
return jsonable_encoder(result)
@setup_required
@login_required
@ -907,20 +910,23 @@ class ToolProviderMCPApi(Resource):
pass
else:
raise ValueError("Server URL is not valid.")
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_user.current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
session.commit()
return {"result": "success"}
@setup_required
@login_required
@ -929,8 +935,11 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
session.commit()
return {"result": "success"}
class ToolMCPAuthApi(Resource):
@ -944,45 +953,50 @@ class ToolMCPAuthApi(Resource):
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
# headers1: if headers is provided, use it and don't need to get token
headers = provider.decrypted_headers or {}
# headers2: Add OAuth token if authed and no headers provided
if not provider.decrypted_headers and provider.authed:
token = OAuthClientProvider(provider_id, tenant_id, for_list=True).tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
try:
# try to connect to MCP server with headers
with MCPClient(
provider.decrypted_server_url,
headers=headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
db_provider = service.get_provider_by_id(provider_id, tenant_id)
if not db_provider:
raise ValueError("provider not found")
except MCPAuthError as e:
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
# Option 1: if headers is provided, use it and don't need to get token
headers = provider_entity.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not provider_entity.headers and provider_entity.authed:
token = provider_entity.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
try:
if provider.decrypted_headers:
raise ValueError(f"Failed to authenticate, please check your headers: {e}") from e
# if auth failed, try to auth with OAuth or exchange token
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
except Exception as e:
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
raise ValueError(f"Failed to authenticate, please try again: {e}") from e
except MCPError as e:
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity
if not provider_entity.headers
else None, # Only use auth retry if no custom headers
auth_callback=auth if not provider_entity.headers else None,
authorization_code=args.get("authorization_code"),
):
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
session.commit()
return {"result": "success"}
except MCPError as e:
service.clear_provider_credentials(provider=db_provider)
session.commit()
raise ValueError(f"Failed to connect to MCP server: {e}") from e
class ToolMCPDetailApi(Resource):
@ -991,8 +1005,10 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
provider = service.get_provider_by_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
class ToolMCPListAllApi(Resource):
@ -1003,9 +1019,11 @@ class ToolMCPListAllApi(Resource):
user = current_user
tenant_id = user.current_tenant_id
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
return [tool.to_dict() for tool in tools]
return [tool.to_dict() for tool in tools]
class ToolMCPUpdateApi(Resource):
@ -1014,11 +1032,13 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
class ToolMCPCallbackApi(Resource):

View File

@ -0,0 +1,202 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import urlparse
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
# Basic identification
id: str
provider_id: str # server_identifier
name: str
tenant_id: str
user_id: str
# Server connection info
server_url: str # encrypted URL
headers: dict[str, str] # encrypted headers
timeout: float
sse_read_timeout: float
# Authentication related
authed: bool
credentials: dict[str, Any] # encrypted credentials
code_verifier: Optional[str] = None # for OAuth
# Tools and display info
tools: list[dict[str, Any]] # parsed tools list
icon: str | dict[str, str] # parsed icon
# Timestamps
created_at: datetime
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
"""Create entity from database model with decryption"""
return cls(
id=db_provider.id,
provider_id=db_provider.server_identifier,
name=db_provider.name,
tenant_id=db_provider.tenant_id,
user_id=db_provider.user_id,
server_url=db_provider.server_url,
headers=db_provider.headers,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
authed=db_provider.authed,
credentials=db_provider.credentials,
tools=db_provider.tool_dict,
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
)
@property
def redirect_url(self) -> str:
"""OAuth redirect URL"""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
@property
def provider_icon(self) -> dict[str, str] | str:
"""Get provider icon, handling both dict and string formats"""
if isinstance(self.icon, dict):
return self.icon
try:
return json.loads(self.icon)
except (json.JSONDecodeError, TypeError):
# If not JSON, assume it's a file path
return file_helpers.get_signed_file_url(self.icon)
def to_api_response(self, user_name: Optional[str] = None) -> dict[str, Any]:
"""Convert to API response format"""
return {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
"icon": self.provider_icon,
"type": ToolProviderType.MCP.value,
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"timeout": self.timeout,
"sse_read_timeout": self.sse_read_timeout,
"masked_headers": self.masked_headers(),
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
def retrieve_client_information(self) -> Optional[OAuthClientInformation]:
"""OAuth client information if available"""
client_info = self.decrypt_credentials().get("client_information", {})
if not client_info:
return None
return OAuthClientInformation.model_validate(client_info)
def retrieve_tokens(self) -> Optional[OAuthTokens]:
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/******"
return base_url
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
masked: dict[str, str] = {}
for key, value in self.decrypt_headers().items():
if len(value) > 6:
masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
else:
masked[key] = "*" * len(value)
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
try:
if not self.headers:
return {}
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
result = encrypter_instance.decrypt(self.headers)
return result
except Exception:
return {}
def decrypt_credentials(
self,
) -> dict[str, Any]:
"""Decrypt credentials"""
try:
if not self.credentials:
return {}
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=[
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key)
for key in self.credentials
],
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials)
except Exception:
return {}

View File

@ -9,8 +9,9 @@ from urllib.parse import urljoin, urlparse
import httpx
from pydantic import BaseModel, ValidationError
from sqlalchemy.orm import Session
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
@ -19,7 +20,9 @@ from core.mcp.types import (
OAuthMetadata,
OAuthTokens,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from services.tools.mcp_oauth_service import MCPOAuthService
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
@ -94,8 +97,13 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
# Save tokens using the service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_tokens(full_state_data.provider_id, full_state_data.tenant_id, tokens)
session.commit()
return full_state_data
@ -295,24 +303,33 @@ def register_client(
def auth(
provider: OAuthClientProvider,
server_url: str,
provider: MCPProviderEntity,
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
full_information = register_client(server_url, server_metadata, client_metadata)
except httpx.RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
# Save client information using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_client_information(provider_id, tenant_id, full_information)
session.commit()
client_information = full_information
# Exchange authorization code for tokens
@ -335,22 +352,36 @@ def auth(
tokens = exchange_authorization(
server_url,
metadata,
server_metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
# Save tokens using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_tokens(provider_id, tenant_id, tokens)
session.commit()
return {"result": "success"}
provider_tokens = provider.tokens()
provider_tokens = provider.retrieve_tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
new_tokens = refresh_authorization(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Save new tokens using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_tokens(provider_id, tenant_id, new_tokens)
session.commit()
return {"result": "success"}
except Exception as e:
raise ValueError(f"Could not refresh OAuth tokens: {e}")
@ -358,12 +389,17 @@ def auth(
# Start new authorization flow
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
server_metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
redirect_url,
provider_id,
tenant_id,
)
provider.save_code_verifier(code_verifier)
# Save code verifier using service layer
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
oauth_service.save_code_verifier(provider_id, tenant_id, code_verifier)
session.commit()
return {"authorization_url": authorization_url}

View File

@ -1,79 +0,0 @@
from typing import Optional
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
class OAuthClientProvider:
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
def client_information(self) -> Optional[OAuthClientInformation]:
"""Loads information about this OAuth client."""
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> Optional[OAuthTokens]:
"""Loads any existing OAuth tokens for the current session."""
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

204
api/core/mcp/auth_client.py Normal file
View File

@ -0,0 +1,204 @@
"""
MCP Client with Authentication Retry Support
This module provides a wrapper around MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from types import TracebackType
from typing import Any, Optional
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry:
"""
A wrapper around MCPClient that provides automatic authentication retry.
This class intercepts MCPAuthError exceptions and attempts to refresh
authentication before retrying the failed operation.
"""
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]] | None = None,
authorization_code: Optional[str] = None,
by_server_id: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
Args:
server_url: The MCP server URL
headers: Optional headers for requests
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
auth_callback: Authentication callback function
authorization_code: Optional authorization code for initial auth
"""
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.provider_entity = provider_entity
self.auth_callback = auth_callback
self.authorization_code = authorization_code
self._has_retried = False
self._client: MCPClient | None = None
self.by_server_id = by_server_id
def _create_client(self) -> MCPClient:
"""Create a new MCPClient instance with current headers."""
return MCPClient(
server_url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
from services.tools.mcp_oauth_service import MCPOAuthService
if not self.provider_entity or not self.auth_callback:
raise error
if self._has_retried:
raise error
self._has_retried = True
try:
# Perform authentication
self.auth_callback(self.provider_entity, self.authorization_code)
# Retrieve new tokens
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
self.provider_entity = oauth_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
# Update headers with new token
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Clear authorization code after first use
self.authorization_code = None
except Exception as e:
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
"""
Execute a function with authentication retry logic.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function call
Raises:
MCPAuthError: If authentication fails after retries
Any other exceptions from the function
"""
try:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Recreate client with new headers
if self._client:
self._client.cleanup()
self._client = self._create_client()
self._client.__enter__()
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager."""
self._client = self._create_client()
# Try to initialize with retry
def initialize():
if self._client is None:
raise ValueError("Client not created")
self._client.__enter__()
return self
return self._execute_with_retry(initialize)
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
"""Exit the context manager."""
if self._client:
self._client.__exit__(exc_type, exc_value, traceback)
self._client = None
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server.
Returns:
List of available tools
Raises:
MCPAuthError: If authentication fails after retries
"""
if not self._client:
raise ValueError("Client not initialized. Use within a context manager.")
return self._execute_with_retry(self._client.list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server.
Args:
tool_name: Name of the tool to invoke
tool_args: Arguments for the tool
Returns:
Result of the tool invocation
Raises:
MCPAuthError: If authentication fails after retries
"""
if not self._client:
raise ValueError("Client not initialized. Use within a context manager.")
return self._execute_with_retry(self._client.invoke_tool, tool_name, tool_args)
def cleanup(self):
"""Clean up resources."""
if self._client:
self._client.cleanup()
self._client = None

View File

@ -46,7 +46,6 @@ class ToolProviderApiEntity(BaseModel):
timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool")
original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool")
@field_validator("tools", mode="before")
@classmethod
@ -72,7 +71,6 @@ class ToolProviderApiEntity(BaseModel):
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {
"id": self.id,
"author": self.author,

View File

@ -1,6 +1,6 @@
import json
from typing import Any, Optional, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -52,18 +52,28 @@ class MCPToolProviderController(ToolProviderController):
"""
from db provider
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
user = db_provider.load_user()
# Convert to entity first
provider_entity = db_provider.to_entity()
return cls.from_entity(provider_entity)
@classmethod
def from_entity(cls, entity: MCPProviderEntity) -> Self:
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
try:
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
except Exception:
remote_mcp_tools = []
tools = [
ToolEntity(
identity=ToolIdentity(
author=user.name if user else "Anonymous",
author="Anonymous", # Tool level author is not stored
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.server_identifier,
icon=db_provider.icon,
provider=entity.provider_id,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
@ -81,22 +91,22 @@ class MCPToolProviderController(ToolProviderController):
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=user.name if user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
author="Anonymous", # Provider level author is not stored in entity
name=entity.name,
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=db_provider.icon,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
provider_id=entity.provider_id,
tenant_id=entity.tenant_id,
server_url=entity.server_url,
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):

View File

@ -4,8 +4,8 @@ from collections.abc import Generator
from typing import Any, Optional
from core.mcp.auth.auth_flow import auth
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
@ -118,59 +118,37 @@ class MCPTool(Tool):
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
# Initialize auth provider
from core.mcp.auth.auth_provider import OAuthClientProvider
# Get provider entity to access tokens
from sqlalchemy.orm import Session
provider = None
from extensions.ext_database import db
from services.tools.mcp_oauth_service import MCPOAuthService
try:
provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=False)
except Exception as e:
# If provider initialization fails, continue without auth
with Session(db.engine) as session:
oauth_service = MCPOAuthService(session=session)
provider_entity = oauth_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Try to get existing token and add to headers
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
except Exception:
# If provider retrieval or token fails, continue without auth
pass
# Try to get existing token and add to headers
if provider:
try:
token = provider.tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
except Exception:
# If token retrieval fails, continue without auth header
pass
# Define a helper function to invoke the tool
def _invoke_with_client(client_headers: dict[str, str]) -> CallToolResult:
with MCPClient(
self.server_url,
headers=client_headers,
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth,
by_server_id=True,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
try:
# First attempt with current headers
return _invoke_with_client(headers)
except MCPAuthError as e:
# Authentication required - try to authenticate
if not provider:
raise ToolInvokeError("Authentication required but no auth provider available") from e
try:
# Perform authentication flow
auth(provider, self.server_url, None, None, False)
token = provider.tokens()
if not token:
raise ToolInvokeError("Authentication failed - no token received")
# Update headers with new token while preserving other headers
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Retry with authenticated headers
return _invoke_with_client(headers)
except MCPAuthError as auth_error:
raise ToolInvokeError("Authentication failed") from auth_error
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:

View File

@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from extensions.ext_database import db
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
@ -59,8 +60,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@ -715,7 +715,9 @@ class ToolManager:
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
@ -770,17 +772,12 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController.from_db(provider)
@ -918,16 +915,13 @@ class ToolManager:
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
if mcp_provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
return mcp_provider.provider_icon
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
mcp_provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
return mcp_provider.provider_icon
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -1,16 +1,12 @@
import json
from datetime import datetime
from typing import Any, Optional, cast
from urllib.parse import urlparse
from typing import TYPE_CHECKING, Any, Optional, cast
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@ -20,6 +16,9 @@ from .engine import db
from .model import Account, App, Tenant
from .types import StringUUID
if TYPE_CHECKING:
from core.entities.mcp_provider import MCPProviderEntity
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(TypeBase):
@ -286,119 +285,34 @@ class MCPToolProvider(Base):
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict[str, Any]:
try:
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
return json.loads(self.encrypted_credentials)
except Exception:
return {}
@property
def mcp_tools(self) -> list[Tool]:
return [Tool(**tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> dict[str, str] | str:
def headers(self) -> dict[str, Any]:
if self.encrypted_headers is None:
return {}
try:
return cast(dict[str, str], json.loads(self.icon))
except json.JSONDecodeError:
return file_helpers.get_signed_file_url(self.icon)
@property
def decrypted_server_url(self) -> str:
return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def decrypted_headers(self) -> dict[str, Any]:
"""Get decrypted headers for MCP server requests."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
try:
if not self.encrypted_headers:
return {}
headers_data = json.loads(self.encrypted_headers)
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
result = encrypter_instance.decrypt(headers_data)
return result
return json.loads(self.encrypted_headers)
except Exception:
return {}
@property
def masked_headers(self) -> dict[str, Any]:
"""Get masked headers for frontend display."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
def tool_dict(self) -> list[dict[str, Any]]:
try:
if not self.encrypted_headers:
return {}
return json.loads(self.tools) if self.tools else []
except (json.JSONDecodeError, TypeError):
return []
headers_data = json.loads(self.encrypted_headers)
def to_entity(self) -> "MCPProviderEntity":
"""Convert to domain entity"""
from core.entities.mcp_provider import MCPProviderEntity
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# First decrypt, then mask
decrypted_headers = encrypter_instance.decrypt(headers_data)
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
return result
except Exception:
return {}
@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
return mask_url(self.decrypted_server_url)
@property
def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController.from_db(self)
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials)
return MCPProviderEntity.from_db_model(self)
class ToolModelInvoke(Base):

View File

@ -0,0 +1,53 @@
"""
MCP OAuth Service - handles OAuth-related database operations
"""
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import OAuthClientInformationFull, OAuthTokens
from services.tools.mcp_tools_manage_service import MCPToolManageService
class MCPOAuthService:
"""Service for handling MCP OAuth operations"""
def __init__(self, session: Session):
self._session = session
self._mcp_service = MCPToolManageService(session=session)
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID"""
if by_server_id:
db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
else:
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
return db_provider.to_entity()
def save_client_information(
self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull
) -> None:
"""Save OAuth client information"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.update_provider_credentials(
provider=db_provider,
credentials={"client_information": client_information.model_dump()},
)
def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None:
"""Save OAuth tokens"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
token_dict = tokens.model_dump()
self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed)
def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None:
"""Save PKCE code verifier"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.update_provider_credentials(
provider=db_provider, credentials={"code_verifier": code_verifier}
)
def clear_credentials(self, provider_id: str, tenant_id: str) -> None:
"""Clear provider credentials"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.clear_provider_credentials(provider=db_provider)

View File

@ -1,24 +1,27 @@
import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import Any, Optional
from sqlalchemy import or_
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import OAuthTokens
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
@ -27,8 +30,10 @@ class MCPToolManageService:
Service class for managing mcp tools.
"""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
def __init__(self, session: Session):
self._session = session
def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
@ -57,48 +62,53 @@ class MCPToolManageService:
return encrypter_instance.encrypt(headers)
@staticmethod
def _retrieve_remote_mcp_tools(server_url: str, headers: dict[str, str], timeout: float, sse_read_timeout: float):
with MCPClient(
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]],
):
"""Retrieve tools from remote MCP server"""
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth_callback,
) as mcp_client:
tools = mcp_client.list_tools()
return tools
@staticmethod
def _process_headers(headers: dict[str, str], tokens: OAuthTokens | None = None):
headers = headers or {}
def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]:
"""Process headers and add OAuth token if available"""
headers = headers.copy() if headers else {}
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider:
"""Get MCP provider by ID"""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return res
return provider
@staticmethod
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider:
"""Get MCP provider by server identifier"""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
if not res:
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return res
return provider
@staticmethod
def create_mcp_provider(
def create_provider(
self,
*,
tenant_id: str,
name: str,
server_url: str,
@ -111,19 +121,20 @@ class MCPToolManageService:
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider"""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
.first()
# Check for existing provider
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
@ -131,13 +142,17 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_url} already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
# Encrypt server URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@ -152,91 +167,68 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
self._session.add(mcp_tool)
self._session.commit()
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
@staticmethod
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
.all()
)
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant"""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
return [
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
for mcp_provider in mcp_providers
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
for provider in mcp_providers
]
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server"""
from core.mcp.auth.auth_flow import auth
from core.mcp.auth.auth_provider import OAuthClientProvider
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
# Load provider and convert to entity
db_provider = self.get_provider_by_id(provider_id, tenant_id)
provider_entity = db_provider.to_entity()
# Handle authentication headers if authed
if not authed:
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
tokens = provider.tokens()
headers = cls._process_headers(headers, tokens)
tokens = provider_entity.retrieve_tokens()
headers = self._process_headers(provider_entity.headers, tokens)
server_url = provider_entity.decrypt_server_url()
try:
tools = cls._retrieve_remote_mcp_tools(server_url, headers, timeout, sse_read_timeout)
except MCPAuthError:
try:
auth(provider, server_url, None, None, False)
tokens = provider.tokens()
re_authed_headers = cls._process_headers(headers, tokens)
tools = cls._retrieve_remote_mcp_tools(server_url, re_authed_headers, timeout, sse_read_timeout)
except Exception:
raise ValueError("Please auth the tool first")
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
try:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
except Exception:
db.session.rollback()
raise
# Update database record with new tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.commit()
user = mcp_provider.load_user()
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=user.name if user else "Anonymous",
server_url=mcp_provider.masked_server_url,
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
plugin_unique_identifier=mcp_provider.server_identifier,
# Create API response using entity
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
@classmethod
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return ToolProviderApiEntity(**response)
db.session.delete(mcp_tool)
db.session.commit()
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider"""
mcp_tool = self.get_provider_by_id(provider_id, tenant_id)
self._session.delete(mcp_tool)
self._session.commit()
@classmethod
def update_mcp_provider(
cls,
def update_provider(
self,
*,
tenant_id: str,
provider_id: str,
name: str,
@ -248,21 +240,27 @@ class MCPToolManageService:
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
) -> None:
"""Update an MCP provider"""
mcp_provider = self.get_provider_by_id(provider_id, tenant_id)
reconnect_result = None
encrypted_server_url = None
server_url_hash = None
# Handle server URL update
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
reconnect_result = self._reconnect_provider(
server_url=server_url,
provider=mcp_provider,
)
try:
# Update basic fields
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
@ -270,6 +268,7 @@ class MCPToolManageService:
)
mcp_provider.server_identifier = server_identifier
# Update server URL if changed
if encrypted_server_url is not None and server_url_hash is not None:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
@ -279,6 +278,7 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
# Update optional fields
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
@ -286,13 +286,15 @@ class MCPToolManageService:
if headers is not None:
# Encrypt headers
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
mcp_provider.encrypted_headers = None
db.session.commit()
self._session.commit()
except IntegrityError as e:
db.session.rollback()
self._session.rollback()
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
@ -302,54 +304,55 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
except Exception:
db.session.rollback()
self._session.rollback()
raise
@classmethod
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
def update_provider_credentials(
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
) -> None:
"""Update provider credentials"""
from core.tools.mcp_tool.provider import MCPToolProviderController
provider_controller = MCPToolProviderController.from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
tenant_id=provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.updated_at = datetime.now()
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
provider.authed = authed
if not authed:
mcp_provider.tools = "[]"
db.session.commit()
provider.tools = "[]"
@classmethod
def clear_mcp_provider_credentials(
cls,
mcp_provider: MCPToolProvider,
):
mcp_provider.tools = "[]"
mcp_provider.encrypted_credentials = "{}"
mcp_provider.updated_at = datetime.now()
mcp_provider.authed = False
db.session.commit()
self._session.commit()
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str) -> dict[str, Any]:
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear provider credentials"""
provider.tools = "[]"
provider.encrypted_credentials = "{}"
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
"""Attempt to reconnect to MCP provider with new server URL"""
from core.mcp.auth.auth_flow import auth
provider_entity = provider.to_entity()
headers = provider_entity.headers
timeout = provider_entity.timeout
sse_read_timeout = provider_entity.sse_read_timeout
try:
with MCPClient(
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth,
) as mcp_client:
tools = mcp_client.list_tools()
return {

View File

@ -221,27 +221,20 @@ class ToolTransformService:
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
# Convert to entity and use its API response method
provider_entity = db_provider.to_entity()
user = db_provider.load_user()
return ToolProviderApiEntity(
id=db_provider.server_identifier if not for_list else db_provider.id,
author=user.name if user else "Anonymous",
name=db_provider.name,
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
response = provider_entity.to_api_response(user_name=user.name if user else None)
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
)
response["server_identifier"] = db_provider.server_identifier
return ToolProviderApiEntity(**response)
@staticmethod
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
@ -403,7 +396,7 @@ class ToolTransformService:
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
"""
Convert MCP JSON schema to tool parameters
@ -412,7 +405,7 @@ class ToolTransformService:
"""
def create_parameter(
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
) -> ToolParameter:
"""Create a ToolParameter instance with given attributes"""
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
@ -427,7 +420,9 @@ class ToolTransformService:
**input_schema_dict,
)
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
def process_properties(
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
) -> list[ToolParameter]:
"""Process properties recursively"""
TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"]