diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index bbf48317af..17a63952fd 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -7,6 +7,7 @@ from flask_restx import ( Resource, reqparse, ) +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config @@ -17,13 +18,13 @@ from controllers.console.wraps import ( setup_required, ) from core.mcp.auth.auth_flow import auth, handle_callback -from core.mcp.auth.auth_provider import OAuthClientProvider -from core.mcp.error import MCPAuthError, MCPError -from core.mcp.mcp_client import MCPClient +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import CredentialType +from extensions.ext_database import db from libs.helper import StrLen, alphanumeric, uuid_value from libs.login import login_required from services.plugin.oauth_service import OAuthProxyService @@ -870,8 +871,9 @@ class ToolProviderMCPApi(Resource): user = current_user if not is_valid_url(args["server_url"]): raise ValueError("Server URL is not valid.") - return jsonable_encoder( - MCPToolManageService.create_mcp_provider( + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + result = service.create_provider( tenant_id=user.current_tenant_id, server_url=args["server_url"], name=args["name"], @@ -884,7 +886,8 @@ class ToolProviderMCPApi(Resource): sse_read_timeout=args["sse_read_timeout"], headers=args["headers"], ) - ) + session.commit() + return jsonable_encoder(result) @setup_required @login_required @@ -907,20 +910,23 @@ class ToolProviderMCPApi(Resource): pass else: raise ValueError("Server URL is not valid.") - MCPToolManageService.update_mcp_provider( - tenant_id=current_user.current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - timeout=args.get("timeout"), - sse_read_timeout=args.get("sse_read_timeout"), - headers=args.get("headers"), - ) - return {"result": "success"} + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + service.update_provider( + tenant_id=current_user.current_tenant_id, + provider_id=args["provider_id"], + server_url=args["server_url"], + name=args["name"], + icon=args["icon"], + icon_type=args["icon_type"], + icon_background=args["icon_background"], + server_identifier=args["server_identifier"], + timeout=args.get("timeout"), + sse_read_timeout=args.get("sse_read_timeout"), + headers=args.get("headers"), + ) + session.commit() + return {"result": "success"} @setup_required @login_required @@ -929,8 +935,11 @@ class ToolProviderMCPApi(Resource): parser = reqparse.RequestParser() parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) - return {"result": "success"} + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) + session.commit() + return {"result": "success"} class ToolMCPAuthApi(Resource): @@ -944,45 +953,50 @@ class ToolMCPAuthApi(Resource): args = parser.parse_args() provider_id = args["provider_id"] tenant_id = current_user.current_tenant_id - provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if not provider: - raise ValueError("provider not found") - # headers1: if headers is provided, use it and don't need to get token - headers = provider.decrypted_headers or {} - # headers2: Add OAuth token if authed and no headers provided - if not provider.decrypted_headers and provider.authed: - token = OAuthClientProvider(provider_id, tenant_id, for_list=True).tokens() - if token: - headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" - try: - # try to connect to MCP server with headers - with MCPClient( - provider.decrypted_server_url, - headers=headers, - timeout=provider.timeout, - sse_read_timeout=provider.sse_read_timeout, - ): - MCPToolManageService.update_mcp_provider_credentials( - mcp_provider=provider, - credentials=provider.decrypted_credentials, - authed=True, - ) - return {"result": "success"} + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + db_provider = service.get_provider_by_id(provider_id, tenant_id) + if not db_provider: + raise ValueError("provider not found") - except MCPAuthError as e: + # Convert to entity + provider_entity = db_provider.to_entity() + server_url = provider_entity.decrypt_server_url() + + # Option 1: if headers is provided, use it and don't need to get token + headers = provider_entity.decrypt_headers() + + # Option 2: Add OAuth token if authed and no headers provided + if not provider_entity.headers and provider_entity.authed: + token = provider_entity.retrieve_tokens() + if token: + headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" try: - if provider.decrypted_headers: - raise ValueError(f"Failed to authenticate, please check your headers: {e}") from e - # if auth failed, try to auth with OAuth or exchange token - auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) - return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"]) - except Exception as e: - MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider) - raise ValueError(f"Failed to authenticate, please try again: {e}") from e - except MCPError as e: - MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider) - raise ValueError(f"Failed to connect to MCP server: {e}") from e + # Use MCPClientWithAuthRetry to handle authentication automatically + with MCPClientWithAuthRetry( + server_url=server_url, + headers=headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + provider_entity=provider_entity + if not provider_entity.headers + else None, # Only use auth retry if no custom headers + auth_callback=auth if not provider_entity.headers else None, + authorization_code=args.get("authorization_code"), + ): + service.update_provider_credentials( + provider=db_provider, + credentials=provider_entity.credentials, + authed=True, + ) + session.commit() + return {"result": "success"} + + except MCPError as e: + service.clear_provider_credentials(provider=db_provider) + session.commit() + raise ValueError(f"Failed to connect to MCP server: {e}") from e class ToolMCPDetailApi(Resource): @@ -991,8 +1005,10 @@ class ToolMCPDetailApi(Resource): @account_initialization_required def get(self, provider_id): user = current_user - provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id) - return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + provider = service.get_provider_by_id(provider_id, user.current_tenant_id) + return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) class ToolMCPListAllApi(Resource): @@ -1003,9 +1019,11 @@ class ToolMCPListAllApi(Resource): user = current_user tenant_id = user.current_tenant_id - tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + tools = service.list_providers(tenant_id=tenant_id) - return [tool.to_dict() for tool in tools] + return [tool.to_dict() for tool in tools] class ToolMCPUpdateApi(Resource): @@ -1014,11 +1032,13 @@ class ToolMCPUpdateApi(Resource): @account_initialization_required def get(self, provider_id): tenant_id = current_user.current_tenant_id - tools = MCPToolManageService.list_mcp_tool_from_remote_server( - tenant_id=tenant_id, - provider_id=provider_id, - ) - return jsonable_encoder(tools) + with Session(db.engine) as session: + service = MCPToolManageService(session=session) + tools = service.list_provider_tools( + tenant_id=tenant_id, + provider_id=provider_id, + ) + return jsonable_encoder(tools) class ToolMCPCallbackApi(Resource): diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py new file mode 100644 index 0000000000..3ce95864f3 --- /dev/null +++ b/api/core/entities/mcp_provider.py @@ -0,0 +1,202 @@ +import json +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional +from urllib.parse import urlparse + +from pydantic import BaseModel + +from configs import dify_config +from core.entities.provider_entities import BasicProviderConfig +from core.file import helpers as file_helpers +from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.utils.encryption import create_provider_encrypter + +if TYPE_CHECKING: + from models.tools import MCPToolProvider + + +class MCPProviderEntity(BaseModel): + """MCP Provider domain entity for business logic operations""" + + # Basic identification + id: str + provider_id: str # server_identifier + name: str + tenant_id: str + user_id: str + + # Server connection info + server_url: str # encrypted URL + headers: dict[str, str] # encrypted headers + timeout: float + sse_read_timeout: float + + # Authentication related + authed: bool + credentials: dict[str, Any] # encrypted credentials + code_verifier: Optional[str] = None # for OAuth + + # Tools and display info + tools: list[dict[str, Any]] # parsed tools list + icon: str | dict[str, str] # parsed icon + + # Timestamps + created_at: datetime + updated_at: datetime + + @classmethod + def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity": + """Create entity from database model with decryption""" + + return cls( + id=db_provider.id, + provider_id=db_provider.server_identifier, + name=db_provider.name, + tenant_id=db_provider.tenant_id, + user_id=db_provider.user_id, + server_url=db_provider.server_url, + headers=db_provider.headers, + timeout=db_provider.timeout, + sse_read_timeout=db_provider.sse_read_timeout, + authed=db_provider.authed, + credentials=db_provider.credentials, + tools=db_provider.tool_dict, + icon=db_provider.icon or "", + created_at=db_provider.created_at, + updated_at=db_provider.updated_at, + ) + + @property + def redirect_url(self) -> str: + """OAuth redirect URL""" + return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback" + + @property + def client_metadata(self) -> OAuthClientMetadata: + """Metadata about this OAuth client.""" + return OAuthClientMetadata( + redirect_uris=[self.redirect_url], + token_endpoint_auth_method="none", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + client_name="Dify", + client_uri="https://github.com/langgenius/dify", + ) + + @property + def provider_icon(self) -> dict[str, str] | str: + """Get provider icon, handling both dict and string formats""" + if isinstance(self.icon, dict): + return self.icon + try: + return json.loads(self.icon) + except (json.JSONDecodeError, TypeError): + # If not JSON, assume it's a file path + return file_helpers.get_signed_file_url(self.icon) + + def to_api_response(self, user_name: Optional[str] = None) -> dict[str, Any]: + """Convert to API response format""" + return { + "id": self.id, + "author": user_name or "Anonymous", + "name": self.name, + "icon": self.provider_icon, + "type": ToolProviderType.MCP.value, + "is_team_authorization": self.authed, + "server_url": self.masked_server_url(), + "server_identifier": self.provider_id, + "timeout": self.timeout, + "sse_read_timeout": self.sse_read_timeout, + "masked_headers": self.masked_headers(), + "updated_at": int(self.updated_at.timestamp()), + "label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(), + "description": I18nObject(en_US="", zh_Hans="").to_dict(), + } + + def retrieve_client_information(self) -> Optional[OAuthClientInformation]: + """OAuth client information if available""" + client_info = self.decrypt_credentials().get("client_information", {}) + if not client_info: + return None + return OAuthClientInformation.model_validate(client_info) + + def retrieve_tokens(self) -> Optional[OAuthTokens]: + """OAuth tokens if available""" + if not self.credentials: + return None + credentials = self.decrypt_credentials() + return OAuthTokens( + access_token=credentials.get("access_token", ""), + token_type=credentials.get("token_type", "Bearer"), + expires_in=int(credentials.get("expires_in", "3600") or 3600), + refresh_token=credentials.get("refresh_token", ""), + ) + + def masked_server_url(self) -> str: + """Masked server URL for display""" + parsed = urlparse(self.decrypt_server_url()) + base_url = f"{parsed.scheme}://{parsed.netloc}" + if parsed.path and parsed.path != "/": + return f"{base_url}/******" + return base_url + + def masked_headers(self) -> dict[str, str]: + """Masked headers for display""" + masked: dict[str, str] = {} + for key, value in self.decrypt_headers().items(): + if len(value) > 6: + masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:] + else: + masked[key] = "*" * len(value) + return masked + + def decrypt_server_url(self) -> str: + """Decrypt server URL""" + + return encrypter.decrypt_token(self.tenant_id, self.server_url) + + def decrypt_headers(self) -> dict[str, Any]: + """Decrypt headers""" + + try: + if not self.headers: + return {} + + # Create dynamic config for all headers as SECRET_INPUT + config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers] + + encrypter_instance, _ = create_provider_encrypter( + tenant_id=self.tenant_id, + config=config, + cache=NoOpProviderCredentialCache(), + ) + + result = encrypter_instance.decrypt(self.headers) + return result + except Exception: + return {} + + def decrypt_credentials( + self, + ) -> dict[str, Any]: + """Decrypt credentials""" + try: + if not self.credentials: + return {} + + encrypter, _ = create_provider_encrypter( + tenant_id=self.tenant_id, + config=[ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) + for key in self.credentials + ], + cache=NoOpProviderCredentialCache(), + ) + + return encrypter.decrypt(self.credentials) + except Exception: + return {} diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 95242fab2c..269f00494e 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -9,8 +9,9 @@ from urllib.parse import urljoin, urlparse import httpx from pydantic import BaseModel, ValidationError +from sqlalchemy.orm import Session -from core.mcp.auth.auth_provider import OAuthClientProvider +from core.entities.mcp_provider import MCPProviderEntity from core.mcp.types import ( LATEST_PROTOCOL_VERSION, OAuthClientInformation, @@ -19,7 +20,9 @@ from core.mcp.types import ( OAuthMetadata, OAuthTokens, ) +from extensions.ext_database import db from extensions.ext_redis import redis_client +from services.tools.mcp_oauth_service import MCPOAuthService OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" @@ -94,8 +97,13 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta full_state_data.code_verifier, full_state_data.redirect_uri, ) - provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True) - provider.save_tokens(tokens) + + # Save tokens using the service layer + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + oauth_service.save_tokens(full_state_data.provider_id, full_state_data.tenant_id, tokens) + session.commit() + return full_state_data @@ -295,24 +303,33 @@ def register_client( def auth( - provider: OAuthClientProvider, - server_url: str, + provider: MCPProviderEntity, authorization_code: Optional[str] = None, state_param: Optional[str] = None, ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" - metadata = discover_oauth_metadata(server_url) + server_url = provider.decrypt_server_url() + server_metadata = discover_oauth_metadata(server_url) + client_metadata = provider.client_metadata + provider_id = provider.id + tenant_id = provider.tenant_id + client_information = provider.retrieve_client_information() + redirect_url = provider.redirect_url - # Handle client registration if needed - client_information = provider.client_information() if not client_information: if authorization_code is not None: raise ValueError("Existing OAuth client information is required when exchanging an authorization code") try: - full_information = register_client(server_url, metadata, provider.client_metadata) + full_information = register_client(server_url, server_metadata, client_metadata) except httpx.RequestError as e: raise ValueError(f"Could not register OAuth client: {e}") - provider.save_client_information(full_information) + + # Save client information using service layer + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + oauth_service.save_client_information(provider_id, tenant_id, full_information) + session.commit() + client_information = full_information # Exchange authorization code for tokens @@ -335,22 +352,36 @@ def auth( tokens = exchange_authorization( server_url, - metadata, + server_metadata, client_information, authorization_code, code_verifier, redirect_uri, ) - provider.save_tokens(tokens) + + # Save tokens using service layer + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + oauth_service.save_tokens(provider_id, tenant_id, tokens) + session.commit() + return {"result": "success"} - provider_tokens = provider.tokens() + provider_tokens = provider.retrieve_tokens() # Handle token refresh or new authorization if provider_tokens and provider_tokens.refresh_token: try: - new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token) - provider.save_tokens(new_tokens) + new_tokens = refresh_authorization( + server_url, server_metadata, client_information, provider_tokens.refresh_token + ) + + # Save new tokens using service layer + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + oauth_service.save_tokens(provider_id, tenant_id, new_tokens) + session.commit() + return {"result": "success"} except Exception as e: raise ValueError(f"Could not refresh OAuth tokens: {e}") @@ -358,12 +389,17 @@ def auth( # Start new authorization flow authorization_url, code_verifier = start_authorization( server_url, - metadata, + server_metadata, client_information, - provider.redirect_url, - provider.mcp_provider.id, - provider.mcp_provider.tenant_id, + redirect_url, + provider_id, + tenant_id, ) - provider.save_code_verifier(code_verifier) + # Save code verifier using service layer + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + oauth_service.save_code_verifier(provider_id, tenant_id, code_verifier) + session.commit() + return {"authorization_url": authorization_url} diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py deleted file mode 100644 index bf1820f744..0000000000 --- a/api/core/mcp/auth/auth_provider.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Optional - -from configs import dify_config -from core.mcp.types import ( - OAuthClientInformation, - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthTokens, -) -from models.tools import MCPToolProvider -from services.tools.mcp_tools_manage_service import MCPToolManageService - - -class OAuthClientProvider: - mcp_provider: MCPToolProvider - - def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False): - if for_list: - self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) - else: - self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id) - - @property - def redirect_url(self) -> str: - """The URL to redirect the user agent to after authorization.""" - return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback" - - @property - def client_metadata(self) -> OAuthClientMetadata: - """Metadata about this OAuth client.""" - return OAuthClientMetadata( - redirect_uris=[self.redirect_url], - token_endpoint_auth_method="none", - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - client_name="Dify", - client_uri="https://github.com/langgenius/dify", - ) - - def client_information(self) -> Optional[OAuthClientInformation]: - """Loads information about this OAuth client.""" - client_information = self.mcp_provider.decrypted_credentials.get("client_information", {}) - if not client_information: - return None - return OAuthClientInformation.model_validate(client_information) - - def save_client_information(self, client_information: OAuthClientInformationFull): - """Saves client information after dynamic registration.""" - MCPToolManageService.update_mcp_provider_credentials( - self.mcp_provider, - {"client_information": client_information.model_dump()}, - ) - - def tokens(self) -> Optional[OAuthTokens]: - """Loads any existing OAuth tokens for the current session.""" - credentials = self.mcp_provider.decrypted_credentials - if not credentials: - return None - return OAuthTokens( - access_token=credentials.get("access_token", ""), - token_type=credentials.get("token_type", "Bearer"), - expires_in=int(credentials.get("expires_in", "3600") or 3600), - refresh_token=credentials.get("refresh_token", ""), - ) - - def save_tokens(self, tokens: OAuthTokens): - """Stores new OAuth tokens for the current session.""" - # update mcp provider credentials - token_dict = tokens.model_dump() - MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) - - def save_code_verifier(self, code_verifier: str): - """Saves a PKCE code verifier for the current session.""" - MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) - - def code_verifier(self) -> str: - """Loads the PKCE code verifier for the current session.""" - # get code verifier from mcp provider credentials - return str(self.mcp_provider.decrypted_credentials.get("code_verifier", "")) diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py new file mode 100644 index 0000000000..d09ed7c9fd --- /dev/null +++ b/api/core/mcp/auth_client.py @@ -0,0 +1,204 @@ +""" +MCP Client with Authentication Retry Support + +This module provides a wrapper around MCPClient that automatically handles +authentication failures and retries operations after refreshing tokens. +""" + +import logging +from collections.abc import Callable +from types import TracebackType +from typing import Any, Optional + +from sqlalchemy.orm import Session + +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.error import MCPAuthError +from core.mcp.mcp_client import MCPClient +from core.mcp.types import CallToolResult, Tool +from extensions.ext_database import db + +logger = logging.getLogger(__name__) + + +class MCPClientWithAuthRetry: + """ + A wrapper around MCPClient that provides automatic authentication retry. + + This class intercepts MCPAuthError exceptions and attempts to refresh + authentication before retrying the failed operation. + """ + + def __init__( + self, + server_url: str, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, + provider_entity: MCPProviderEntity | None = None, + auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]] | None = None, + authorization_code: Optional[str] = None, + by_server_id: bool = False, + ): + """ + Initialize the MCP client with auth retry capability. + + Args: + server_url: The MCP server URL + headers: Optional headers for requests + timeout: Request timeout + sse_read_timeout: SSE read timeout + provider_entity: Provider entity for authentication + auth_callback: Authentication callback function + authorization_code: Optional authorization code for initial auth + """ + self.server_url = server_url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.provider_entity = provider_entity + self.auth_callback = auth_callback + self.authorization_code = authorization_code + self._has_retried = False + self._client: MCPClient | None = None + self.by_server_id = by_server_id + + def _create_client(self) -> MCPClient: + """Create a new MCPClient instance with current headers.""" + return MCPClient( + server_url=self.server_url, + headers=self.headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + ) + + def _handle_auth_error(self, error: MCPAuthError) -> None: + """ + Handle authentication error by refreshing tokens. + + Args: + error: The authentication error + + Raises: + MCPAuthError: If authentication fails or max retries reached + """ + from services.tools.mcp_oauth_service import MCPOAuthService + + if not self.provider_entity or not self.auth_callback: + raise error + + if self._has_retried: + raise error + + self._has_retried = True + + try: + # Perform authentication + self.auth_callback(self.provider_entity, self.authorization_code) + + # Retrieve new tokens + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + self.provider_entity = oauth_service.get_provider_entity( + self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id + ) + token = self.provider_entity.retrieve_tokens() + if not token: + raise MCPAuthError("Authentication failed - no token received") + + # Update headers with new token + self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" + + # Clear authorization code after first use + self.authorization_code = None + + except Exception as e: + logger.exception("Authentication retry failed") + raise MCPAuthError(f"Authentication retry failed: {e}") from e + + def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any: + """ + Execute a function with authentication retry logic. + + Args: + func: The function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + The result of the function call + + Raises: + MCPAuthError: If authentication fails after retries + Any other exceptions from the function + """ + try: + return func(*args, **kwargs) + except MCPAuthError as e: + self._handle_auth_error(e) + # Recreate client with new headers + if self._client: + self._client.cleanup() + self._client = self._create_client() + self._client.__enter__() + return func(*args, **kwargs) + finally: + # Reset retry flag after operation completes + self._has_retried = False + + def __enter__(self): + """Enter the context manager.""" + self._client = self._create_client() + + # Try to initialize with retry + def initialize(): + if self._client is None: + raise ValueError("Client not created") + self._client.__enter__() + return self + + return self._execute_with_retry(initialize) + + def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None): + """Exit the context manager.""" + if self._client: + self._client.__exit__(exc_type, exc_value, traceback) + self._client = None + + def list_tools(self) -> list[Tool]: + """ + List available tools from the MCP server. + + Returns: + List of available tools + + Raises: + MCPAuthError: If authentication fails after retries + """ + if not self._client: + raise ValueError("Client not initialized. Use within a context manager.") + return self._execute_with_retry(self._client.list_tools) + + def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: + """ + Invoke a tool on the MCP server. + + Args: + tool_name: Name of the tool to invoke + tool_args: Arguments for the tool + + Returns: + Result of the tool invocation + + Raises: + MCPAuthError: If authentication fails after retries + """ + if not self._client: + raise ValueError("Client not initialized. Use within a context manager.") + return self._execute_with_retry(self._client.invoke_tool, tool_name, tool_args) + + def cleanup(self): + """Clean up resources.""" + if self._client: + self._client.cleanup() + self._client = None diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ca3be26ff9..27fe70bfc4 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -46,7 +46,6 @@ class ToolProviderApiEntity(BaseModel): timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool") sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool") masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool") - original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool") @field_validator("tools", mode="before") @classmethod @@ -72,7 +71,6 @@ class ToolProviderApiEntity(BaseModel): optional_fields.update(self.optional_field("timeout", self.timeout)) optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout)) optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) - optional_fields.update(self.optional_field("original_headers", self.original_headers)) return { "id": self.id, "author": self.author, diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 1c1b4fe5e6..bd8bc73e63 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,6 +1,6 @@ -import json from typing import Any, Optional, Self +from core.entities.mcp_provider import MCPProviderEntity from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime @@ -52,18 +52,28 @@ class MCPToolProviderController(ToolProviderController): """ from db provider """ - tools = [] - tools_data = json.loads(db_provider.tools) - remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data] - user = db_provider.load_user() + # Convert to entity first + provider_entity = db_provider.to_entity() + return cls.from_entity(provider_entity) + + @classmethod + def from_entity(cls, entity: MCPProviderEntity) -> Self: + """ + create a MCPToolProviderController from a MCPProviderEntity + """ + try: + remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools] + except Exception: + remote_mcp_tools = [] + tools = [ ToolEntity( identity=ToolIdentity( - author=user.name if user else "Anonymous", + author="Anonymous", # Tool level author is not stored name=remote_mcp_tool.name, label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name), - provider=db_provider.server_identifier, - icon=db_provider.icon, + provider=entity.provider_id, + icon=entity.icon if isinstance(entity.icon, str) else "", ), parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema), description=ToolDescription( @@ -81,22 +91,22 @@ class MCPToolProviderController(ToolProviderController): return cls( entity=ToolProviderEntityWithPlugin( identity=ToolProviderIdentity( - author=user.name if user else "Anonymous", - name=db_provider.name, - label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), + author="Anonymous", # Provider level author is not stored in entity + name=entity.name, + label=I18nObject(en_US=entity.name, zh_Hans=entity.name), description=I18nObject(en_US="", zh_Hans=""), - icon=db_provider.icon, + icon=entity.icon if isinstance(entity.icon, str) else "", ), plugin_id=None, credentials_schema=[], tools=tools, ), - provider_id=db_provider.server_identifier or "", - tenant_id=db_provider.tenant_id or "", - server_url=db_provider.decrypted_server_url, - headers=db_provider.decrypted_headers or {}, - timeout=db_provider.timeout, - sse_read_timeout=db_provider.sse_read_timeout, + provider_id=entity.provider_id, + tenant_id=entity.tenant_id, + server_url=entity.server_url, + headers=entity.headers, + timeout=entity.timeout, + sse_read_timeout=entity.sse_read_timeout, ) def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 159ec9b02e..d29c931d5d 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -4,8 +4,8 @@ from collections.abc import Generator from typing import Any, Optional from core.mcp.auth.auth_flow import auth -from core.mcp.error import MCPAuthError, MCPConnectionError -from core.mcp.mcp_client import MCPClient +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPConnectionError from core.mcp.types import CallToolResult, ImageContent, TextContent from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -118,59 +118,37 @@ class MCPTool(Tool): headers = self.headers.copy() if self.headers else {} tool_parameters = self._handle_none_parameter(tool_parameters) - # Initialize auth provider - from core.mcp.auth.auth_provider import OAuthClientProvider + # Get provider entity to access tokens + from sqlalchemy.orm import Session - provider = None + from extensions.ext_database import db + from services.tools.mcp_oauth_service import MCPOAuthService try: - provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=False) - except Exception as e: - # If provider initialization fails, continue without auth + with Session(db.engine) as session: + oauth_service = MCPOAuthService(session=session) + provider_entity = oauth_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) + + # Try to get existing token and add to headers + tokens = provider_entity.retrieve_tokens() + if tokens and tokens.access_token: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + except Exception: + # If provider retrieval or token fails, continue without auth pass - # Try to get existing token and add to headers - if provider: - try: - token = provider.tokens() - if token: - headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" - except Exception: - # If token retrieval fails, continue without auth header - pass - - # Define a helper function to invoke the tool - def _invoke_with_client(client_headers: dict[str, str]) -> CallToolResult: - with MCPClient( - self.server_url, - headers=client_headers, + # Use MCPClientWithAuthRetry to handle authentication automatically + try: + with MCPClientWithAuthRetry( + server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url, + headers=headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout, + provider_entity=provider_entity, + auth_callback=auth, + by_server_id=True, ) as mcp_client: return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) - - try: - # First attempt with current headers - return _invoke_with_client(headers) - except MCPAuthError as e: - # Authentication required - try to authenticate - if not provider: - raise ToolInvokeError("Authentication required but no auth provider available") from e - - try: - # Perform authentication flow - auth(provider, self.server_url, None, None, False) - token = provider.tokens() - if not token: - raise ToolInvokeError("Authentication failed - no token received") - - # Update headers with new token while preserving other headers - headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}" - - # Retry with authenticated headers - return _invoke_with_client(headers) - except MCPAuthError as auth_error: - raise ToolInvokeError("Authentication failed") from auth_error except MCPConnectionError as e: raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e except Exception as e: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b29da3f0ba..afd8c434fd 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.workflow.entities.variable_pool import VariablePool +from extensions.ext_database import db from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService @@ -59,8 +60,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -715,7 +715,9 @@ class ToolManager: ) result_providers[f"workflow_provider.{user_provider.name}"] = user_provider if "mcp" in filters: - mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True) + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True) for mcp_provider in mcp_providers: result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider @@ -770,17 +772,12 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: MCPToolProvider | None = ( - db.session.query(MCPToolProvider) - .where( - MCPToolProvider.server_identifier == provider_id, - MCPToolProvider.tenant_id == tenant_id, - ) - .first() - ) - - if provider is None: - raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + try: + provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id) + except ValueError: + raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") controller = MCPToolProviderController.from_db(provider) @@ -918,16 +915,13 @@ class ToolManager: @classmethod def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str: try: - mcp_provider: MCPToolProvider | None = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) - .first() - ) - - if mcp_provider is None: - raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") - - return mcp_provider.provider_icon + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + try: + mcp_provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id) + return mcp_provider.provider_icon + except ValueError: + raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/models/tools.py b/api/models/tools.py index 96ad76eae5..141393dc8e 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,16 +1,12 @@ import json from datetime import datetime -from typing import Any, Optional, cast -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Any, Optional, cast import sqlalchemy as sa from deprecated import deprecated from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column -from core.file import helpers as file_helpers -from core.helper import encrypter -from core.mcp.types import Tool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration @@ -20,6 +16,9 @@ from .engine import db from .model import Account, App, Tenant from .types import StringUUID +if TYPE_CHECKING: + from core.entities.mcp_provider import MCPProviderEntity + # system level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthSystemClient(TypeBase): @@ -286,119 +285,34 @@ class MCPToolProvider(Base): def load_user(self) -> Account | None: return db.session.query(Account).where(Account.id == self.user_id).first() - @property - def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - @property def credentials(self) -> dict[str, Any]: try: - return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} + return json.loads(self.encrypted_credentials) except Exception: return {} @property - def mcp_tools(self) -> list[Tool]: - return [Tool(**tool) for tool in json.loads(self.tools)] - - @property - def provider_icon(self) -> dict[str, str] | str: + def headers(self) -> dict[str, Any]: + if self.encrypted_headers is None: + return {} try: - return cast(dict[str, str], json.loads(self.icon)) - except json.JSONDecodeError: - return file_helpers.get_signed_file_url(self.icon) - - @property - def decrypted_server_url(self) -> str: - return encrypter.decrypt_token(self.tenant_id, self.server_url) - - @property - def decrypted_headers(self) -> dict[str, Any]: - """Get decrypted headers for MCP server requests.""" - from core.entities.provider_entities import BasicProviderConfig - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.utils.encryption import create_provider_encrypter - - try: - if not self.encrypted_headers: - return {} - - headers_data = json.loads(self.encrypted_headers) - - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - result = encrypter_instance.decrypt(headers_data) - return result + return json.loads(self.encrypted_headers) except Exception: return {} @property - def masked_headers(self) -> dict[str, Any]: - """Get masked headers for frontend display.""" - from core.entities.provider_entities import BasicProviderConfig - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.utils.encryption import create_provider_encrypter - + def tool_dict(self) -> list[dict[str, Any]]: try: - if not self.encrypted_headers: - return {} + return json.loads(self.tools) if self.tools else [] + except (json.JSONDecodeError, TypeError): + return [] - headers_data = json.loads(self.encrypted_headers) + def to_entity(self) -> "MCPProviderEntity": + """Convert to domain entity""" + from core.entities.mcp_provider import MCPProviderEntity - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - # First decrypt, then mask - decrypted_headers = encrypter_instance.decrypt(headers_data) - result = encrypter_instance.mask_tool_credentials(decrypted_headers) - return result - except Exception: - return {} - - @property - def masked_server_url(self) -> str: - def mask_url(url: str, mask_char: str = "*") -> str: - """ - mask the url to a simple string - """ - parsed = urlparse(url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - if parsed.path and parsed.path != "/": - return f"{base_url}/{mask_char * 6}" - else: - return base_url - - return mask_url(self.decrypted_server_url) - - @property - def decrypted_credentials(self) -> dict[str, Any]: - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.encryption import create_provider_encrypter - - provider_controller = MCPToolProviderController.from_db(self) - - encrypter, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - cache=NoOpProviderCredentialCache(), - ) - - return encrypter.decrypt(self.credentials) + return MCPProviderEntity.from_db_model(self) class ToolModelInvoke(Base): diff --git a/api/services/tools/mcp_oauth_service.py b/api/services/tools/mcp_oauth_service.py new file mode 100644 index 0000000000..5f9904b110 --- /dev/null +++ b/api/services/tools/mcp_oauth_service.py @@ -0,0 +1,53 @@ +""" +MCP OAuth Service - handles OAuth-related database operations +""" + +from sqlalchemy.orm import Session + +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.types import OAuthClientInformationFull, OAuthTokens +from services.tools.mcp_tools_manage_service import MCPToolManageService + + +class MCPOAuthService: + """Service for handling MCP OAuth operations""" + + def __init__(self, session: Session): + self._session = session + self._mcp_service = MCPToolManageService(session=session) + + def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity: + """Get provider entity by ID""" + if by_server_id: + db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id) + else: + db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) + return db_provider.to_entity() + + def save_client_information( + self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull + ) -> None: + """Save OAuth client information""" + db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) + self._mcp_service.update_provider_credentials( + provider=db_provider, + credentials={"client_information": client_information.model_dump()}, + ) + + def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None: + """Save OAuth tokens""" + db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) + token_dict = tokens.model_dump() + self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed) + + def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None: + """Save PKCE code verifier""" + db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) + self._mcp_service.update_provider_credentials( + provider=db_provider, credentials={"code_verifier": code_verifier} + ) + + def clear_credentials(self, provider_id: str, tenant_id: str) -> None: + """Clear provider credentials""" + db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) + self._mcp_service.clear_provider_credentials(provider=db_provider) diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 8088cb009c..ff2ffcd22b 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -1,24 +1,27 @@ import hashlib import json +import logging +from collections.abc import Callable from datetime import datetime -from typing import Any +from typing import Any, Optional -from sqlalchemy import or_ +from sqlalchemy import or_, select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session +from core.entities.mcp_provider import MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError -from core.mcp.mcp_client import MCPClient from core.mcp.types import OAuthTokens from core.tools.entities.api_entities import ToolProviderApiEntity -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.encryption import ProviderConfigEncrypter -from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService +logger = logging.getLogger(__name__) + UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" @@ -27,8 +30,10 @@ class MCPToolManageService: Service class for managing mcp tools. """ - @staticmethod - def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]: + def __init__(self, session: Session): + self._session = session + + def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]: """ Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT. @@ -57,48 +62,53 @@ class MCPToolManageService: return encrypter_instance.encrypt(headers) - @staticmethod - def _retrieve_remote_mcp_tools(server_url: str, headers: dict[str, str], timeout: float, sse_read_timeout: float): - with MCPClient( + def _retrieve_remote_mcp_tools( + self, + server_url: str, + headers: dict[str, str], + provider_entity: MCPProviderEntity, + auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]], + ): + """Retrieve tools from remote MCP server""" + with MCPClientWithAuthRetry( server_url, headers=headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + provider_entity=provider_entity, + auth_callback=auth_callback, ) as mcp_client: tools = mcp_client.list_tools() return tools - @staticmethod - def _process_headers(headers: dict[str, str], tokens: OAuthTokens | None = None): - headers = headers or {} + def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]: + """Process headers and add OAuth token if available""" + headers = headers.copy() if headers else {} if tokens: headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" return headers - @staticmethod - def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: - res = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) - .first() - ) - if not res: + def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider: + """Get MCP provider by ID""" + stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) + provider = self._session.scalar(stmt) + if not provider: raise ValueError("MCP tool not found") - return res + return provider - @staticmethod - def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider: - res = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) - .first() + def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider: + """Get MCP provider by server identifier""" + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier ) - if not res: + provider = self._session.scalar(stmt) + if not provider: raise ValueError("MCP tool not found") - return res + return provider - @staticmethod - def create_mcp_provider( + def create_provider( + self, + *, tenant_id: str, name: str, server_url: str, @@ -111,19 +121,20 @@ class MCPToolManageService: sse_read_timeout: float, headers: dict[str, str] | None = None, ) -> ToolProviderApiEntity: + """Create a new MCP provider""" server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - existing_provider = ( - db.session.query(MCPToolProvider) - .where( - MCPToolProvider.tenant_id == tenant_id, - or_( - MCPToolProvider.name == name, - MCPToolProvider.server_url_hash == server_url_hash, - MCPToolProvider.server_identifier == server_identifier, - ), - ) - .first() + + # Check for existing provider + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, + or_( + MCPToolProvider.name == name, + MCPToolProvider.server_url_hash == server_url_hash, + MCPToolProvider.server_identifier == server_identifier, + ), ) + existing_provider = self._session.scalar(stmt) + if existing_provider: if existing_provider.name == name: raise ValueError(f"MCP tool {name} already exists") @@ -131,13 +142,17 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_url} already exists") if existing_provider.server_identifier == server_identifier: raise ValueError(f"MCP tool {server_identifier} already exists") + + # Encrypt server URL encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) + # Encrypt headers encrypted_headers = None if headers: - encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id) + encrypted_headers_dict = self._encrypt_headers(headers, tenant_id) encrypted_headers = json.dumps(encrypted_headers_dict) + # Create provider mcp_tool = MCPToolProvider( tenant_id=tenant_id, name=name, @@ -152,91 +167,68 @@ class MCPToolManageService: sse_read_timeout=sse_read_timeout, encrypted_headers=encrypted_headers, ) - db.session.add(mcp_tool) - db.session.commit() + + self._session.add(mcp_tool) + self._session.commit() + return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) - @staticmethod - def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: - mcp_providers = ( - db.session.query(MCPToolProvider) - .where(MCPToolProvider.tenant_id == tenant_id) - .order_by(MCPToolProvider.name) - .all() - ) + def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: + """List all MCP providers for a tenant""" + stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name) + + mcp_providers = self._session.scalars(stmt).all() + return [ - ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list) - for mcp_provider in mcp_providers + ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list) + for provider in mcp_providers ] - @classmethod - def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: + def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: + """List tools from remote MCP server""" from core.mcp.auth.auth_flow import auth - from core.mcp.auth.auth_provider import OAuthClientProvider - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - server_url = mcp_provider.decrypted_server_url - authed = mcp_provider.authed - headers = mcp_provider.decrypted_headers - timeout = mcp_provider.timeout - sse_read_timeout = mcp_provider.sse_read_timeout + # Load provider and convert to entity + db_provider = self.get_provider_by_id(provider_id, tenant_id) + provider_entity = db_provider.to_entity() # Handle authentication headers if authed - if not authed: + if not provider_entity.authed: raise ValueError("Please auth the tool first") - provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) - tokens = provider.tokens() - headers = cls._process_headers(headers, tokens) - + tokens = provider_entity.retrieve_tokens() + headers = self._process_headers(provider_entity.headers, tokens) + server_url = provider_entity.decrypt_server_url() try: - tools = cls._retrieve_remote_mcp_tools(server_url, headers, timeout, sse_read_timeout) - except MCPAuthError: - try: - auth(provider, server_url, None, None, False) - tokens = provider.tokens() - re_authed_headers = cls._process_headers(headers, tokens) - tools = cls._retrieve_remote_mcp_tools(server_url, re_authed_headers, timeout, sse_read_timeout) - except Exception: - raise ValueError("Please auth the tool first") + tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth) except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") - try: - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) - mcp_provider.authed = True - mcp_provider.updated_at = datetime.now() - db.session.commit() - except Exception: - db.session.rollback() - raise + # Update database record with new tools + db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + db_provider.authed = True + db_provider.updated_at = datetime.now() + self._session.commit() - user = mcp_provider.load_user() - return ToolProviderApiEntity( - id=mcp_provider.id, - name=mcp_provider.name, - tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools), - type=ToolProviderType.MCP, - icon=mcp_provider.icon, - author=user.name if user else "Anonymous", - server_url=mcp_provider.masked_server_url, - updated_at=int(mcp_provider.updated_at.timestamp()), - description=I18nObject(en_US="", zh_Hans=""), - label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), - plugin_unique_identifier=mcp_provider.server_identifier, + # Create API response using entity + user = db_provider.load_user() + response = provider_entity.to_api_response( + user_name=user.name if user else None, ) + response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools) + response["plugin_unique_identifier"] = provider_entity.provider_id - @classmethod - def delete_mcp_tool(cls, tenant_id: str, provider_id: str): - mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + return ToolProviderApiEntity(**response) - db.session.delete(mcp_tool) - db.session.commit() + def delete_provider(self, *, tenant_id: str, provider_id: str) -> None: + """Delete an MCP provider""" + mcp_tool = self.get_provider_by_id(provider_id, tenant_id) + self._session.delete(mcp_tool) + self._session.commit() - @classmethod - def update_mcp_provider( - cls, + def update_provider( + self, + *, tenant_id: str, provider_id: str, name: str, @@ -248,21 +240,27 @@ class MCPToolManageService: timeout: float | None = None, sse_read_timeout: float | None = None, headers: dict[str, str] | None = None, - ): - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + ) -> None: + """Update an MCP provider""" + mcp_provider = self.get_provider_by_id(provider_id, tenant_id) reconnect_result = None encrypted_server_url = None server_url_hash = None + # Handle server URL update if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() if server_url_hash != mcp_provider.server_url_hash: - reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id) + reconnect_result = self._reconnect_provider( + server_url=server_url, + provider=mcp_provider, + ) try: + # Update basic fields mcp_provider.updated_at = datetime.now() mcp_provider.name = name mcp_provider.icon = ( @@ -270,6 +268,7 @@ class MCPToolManageService: ) mcp_provider.server_identifier = server_identifier + # Update server URL if changed if encrypted_server_url is not None and server_url_hash is not None: mcp_provider.server_url = encrypted_server_url mcp_provider.server_url_hash = server_url_hash @@ -279,6 +278,7 @@ class MCPToolManageService: mcp_provider.tools = reconnect_result["tools"] mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + # Update optional fields if timeout is not None: mcp_provider.timeout = timeout if sse_read_timeout is not None: @@ -286,13 +286,15 @@ class MCPToolManageService: if headers is not None: # Encrypt headers if headers: - encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id) + encrypted_headers_dict = self._encrypt_headers(headers, tenant_id) mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict) else: mcp_provider.encrypted_headers = None - db.session.commit() + + self._session.commit() + except IntegrityError as e: - db.session.rollback() + self._session.rollback() error_msg = str(e.orig) if "unique_mcp_provider_name" in error_msg: raise ValueError(f"MCP tool {name} already exists") @@ -302,54 +304,55 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_identifier} already exists") raise except Exception: - db.session.rollback() + self._session.rollback() raise - @classmethod - def update_mcp_provider_credentials( - cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False - ): + def update_provider_credentials( + self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False + ) -> None: + """Update provider credentials""" from core.tools.mcp_tool.provider import MCPToolProviderController - provider_controller = MCPToolProviderController.from_db(mcp_provider) + provider_controller = MCPToolProviderController.from_db(provider) tool_configuration = ProviderConfigEncrypter( - tenant_id=mcp_provider.tenant_id, + tenant_id=provider.tenant_id, config=list(provider_controller.get_credentials_schema()), provider_config_cache=NoOpProviderCredentialCache(), ) - credentials = tool_configuration.encrypt(credentials) - mcp_provider.updated_at = datetime.now() - mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials}) - mcp_provider.authed = authed + encrypted_credentials = tool_configuration.encrypt(credentials) + provider.updated_at = datetime.now() + provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials}) + provider.authed = authed if not authed: - mcp_provider.tools = "[]" - db.session.commit() + provider.tools = "[]" - @classmethod - def clear_mcp_provider_credentials( - cls, - mcp_provider: MCPToolProvider, - ): - mcp_provider.tools = "[]" - mcp_provider.encrypted_credentials = "{}" - mcp_provider.updated_at = datetime.now() - mcp_provider.authed = False - db.session.commit() + self._session.commit() - @classmethod - def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str) -> dict[str, Any]: - # Get the existing provider to access headers and timeout settings - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - headers = mcp_provider.decrypted_headers - timeout = mcp_provider.timeout - sse_read_timeout = mcp_provider.sse_read_timeout + def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None: + """Clear provider credentials""" + provider.tools = "[]" + provider.encrypted_credentials = "{}" + provider.updated_at = datetime.now() + provider.authed = False + self._session.commit() + + def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]: + """Attempt to reconnect to MCP provider with new server URL""" + from core.mcp.auth.auth_flow import auth + + provider_entity = provider.to_entity() + headers = provider_entity.headers + timeout = provider_entity.timeout + sse_read_timeout = provider_entity.sse_read_timeout try: - with MCPClient( + with MCPClientWithAuthRetry( server_url, headers=headers, timeout=timeout, sse_read_timeout=sse_read_timeout, + provider_entity=provider_entity, + auth_callback=auth, ) as mcp_client: tools = mcp_client.list_tools() return { diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index f6d89dc262..1692e10889 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -221,27 +221,20 @@ class ToolTransformService: @staticmethod def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity: + # Convert to entity and use its API response method + provider_entity = db_provider.to_entity() user = db_provider.load_user() - return ToolProviderApiEntity( - id=db_provider.server_identifier if not for_list else db_provider.id, - author=user.name if user else "Anonymous", - name=db_provider.name, - icon=db_provider.provider_icon, - type=ToolProviderType.MCP, - is_team_authorization=db_provider.authed, - server_url=db_provider.masked_server_url, - tools=ToolTransformService.mcp_tool_to_user_tool( - db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] - ), - updated_at=int(db_provider.updated_at.timestamp()), - label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), - description=I18nObject(en_US="", zh_Hans=""), - server_identifier=db_provider.server_identifier, - timeout=db_provider.timeout, - sse_read_timeout=db_provider.sse_read_timeout, - masked_headers=db_provider.masked_headers, - original_headers=db_provider.decrypted_headers, + + response = provider_entity.to_api_response(user_name=user.name if user else None) + + # Add additional fields specific to the transform + response["id"] = db_provider.server_identifier if not for_list else db_provider.id + response["tools"] = ToolTransformService.mcp_tool_to_user_tool( + db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] ) + response["server_identifier"] = db_provider.server_identifier + + return ToolProviderApiEntity(**response) @staticmethod def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]: @@ -403,7 +396,7 @@ class ToolTransformService: ) @staticmethod - def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]: + def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]: """ Convert MCP JSON schema to tool parameters @@ -412,7 +405,7 @@ class ToolTransformService: """ def create_parameter( - name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None + name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None ) -> ToolParameter: """Create a ToolParameter instance with given attributes""" input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {} @@ -427,7 +420,9 @@ class ToolTransformService: **input_schema_dict, ) - def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]: + def process_properties( + props: dict[str, dict[str, Any]], required: list[str], prefix: str = "" + ) -> list[ToolParameter]: """Process properties recursively""" TYPE_MAPPING = {"integer": "number", "float": "number"} COMPLEX_TYPES = ["array", "object"]