diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 17a63952fd..094370f1cc 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -956,7 +956,7 @@ class ToolMCPAuthApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) - db_provider = service.get_provider_by_id(provider_id, tenant_id) + db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) if not db_provider: raise ValueError("provider not found") @@ -984,6 +984,7 @@ class ToolMCPAuthApi(Resource): else None, # Only use auth retry if no custom headers auth_callback=auth if not provider_entity.headers else None, authorization_code=args.get("authorization_code"), + mcp_service=service, ): service.update_provider_credentials( provider=db_provider, @@ -1007,7 +1008,7 @@ class ToolMCPDetailApi(Resource): user = current_user with Session(db.engine) as session: service = MCPToolManageService(session=session) - provider = service.get_provider_by_id(provider_id, user.current_tenant_id) + provider = service.get_provider(provider_id=provider_id, tenant_id=user.current_tenant_id) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @@ -1049,7 +1050,13 @@ class ToolMCPCallbackApi(Resource): args = parser.parse_args() state_key = args["state"] authorization_code = args["code"] - handle_callback(state_key, authorization_code) + + # Create service instance for handle_callback + with Session(db.engine) as session: + mcp_service = MCPToolManageService(session=session) + handle_callback(state_key, authorization_code, mcp_service) + session.commit() + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 269f00494e..a56c6ef86e 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -4,12 +4,11 @@ import json import os import secrets import urllib.parse -from typing import Optional +from typing import TYPE_CHECKING, Optional from urllib.parse import urljoin, urlparse import httpx from pydantic import BaseModel, ValidationError -from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPProviderEntity from core.mcp.types import ( @@ -20,9 +19,10 @@ from core.mcp.types import ( OAuthMetadata, OAuthTokens, ) -from extensions.ext_database import db from extensions.ext_redis import redis_client -from services.tools.mcp_oauth_service import MCPOAuthService + +if TYPE_CHECKING: + from services.tools.mcp_tools_manage_service import MCPToolManageService OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" @@ -84,7 +84,7 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: raise ValueError(f"Invalid state parameter: {str(e)}") -def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState: +def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPToolManageService") -> OAuthCallbackState: """Handle the callback from the OAuth provider.""" # Retrieve state data from Redis (state is automatically deleted after retrieval) full_state_data = _retrieve_redis_state(state_key) @@ -99,10 +99,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta ) # Save tokens using the service layer - with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - oauth_service.save_tokens(full_state_data.provider_id, full_state_data.tenant_id, tokens) - session.commit() + mcp_service.save_oauth_data(full_state_data.provider_id, full_state_data.tenant_id, tokens.model_dump(), "tokens") return full_state_data @@ -304,6 +301,7 @@ def register_client( def auth( provider: MCPProviderEntity, + mcp_service: "MCPToolManageService", authorization_code: Optional[str] = None, state_param: Optional[str] = None, ) -> dict[str, str]: @@ -325,10 +323,9 @@ def auth( raise ValueError(f"Could not register OAuth client: {e}") # Save client information using service layer - with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - oauth_service.save_client_information(provider_id, tenant_id, full_information) - session.commit() + mcp_service.save_oauth_data( + provider_id, tenant_id, {"client_information": full_information.model_dump()}, "client_info" + ) client_information = full_information @@ -360,10 +357,7 @@ def auth( ) # Save tokens using service layer - with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - oauth_service.save_tokens(provider_id, tenant_id, tokens) - session.commit() + mcp_service.save_oauth_data(provider_id, tenant_id, tokens.model_dump(), "tokens") return {"result": "success"} @@ -377,10 +371,7 @@ def auth( ) # Save new tokens using service layer - with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - oauth_service.save_tokens(provider_id, tenant_id, new_tokens) - session.commit() + mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens") return {"result": "success"} except Exception as e: @@ -397,9 +388,6 @@ def auth( ) # Save code verifier using service layer - with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - oauth_service.save_code_verifier(provider_id, tenant_id, code_verifier) - session.commit() + mcp_service.save_oauth_data(provider_id, tenant_id, {"code_verifier": code_verifier}, "code_verifier") return {"authorization_url": authorization_url} diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index d09ed7c9fd..ede2518dde 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -8,15 +8,15 @@ authentication failures and retries operations after refreshing tokens. import logging from collections.abc import Callable from types import TracebackType -from typing import Any, Optional - -from sqlalchemy.orm import Session +from typing import TYPE_CHECKING, Any, Optional from core.entities.mcp_provider import MCPProviderEntity from core.mcp.error import MCPAuthError from core.mcp.mcp_client import MCPClient from core.mcp.types import CallToolResult, Tool -from extensions.ext_database import db + +if TYPE_CHECKING: + from services.tools.mcp_tools_manage_service import MCPToolManageService logger = logging.getLogger(__name__) @@ -36,9 +36,11 @@ class MCPClientWithAuthRetry: timeout: float | None = None, sse_read_timeout: float | None = None, provider_entity: MCPProviderEntity | None = None, - auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]] | None = None, + auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]] + | None = None, authorization_code: Optional[str] = None, by_server_id: bool = False, + mcp_service: Optional["MCPToolManageService"] = None, ): """ Initialize the MCP client with auth retry capability. @@ -62,6 +64,7 @@ class MCPClientWithAuthRetry: self._has_retried = False self._client: MCPClient | None = None self.by_server_id = by_server_id + self.mcp_service = mcp_service def _create_client(self) -> MCPClient: """Create a new MCPClient instance with current headers.""" @@ -82,11 +85,8 @@ class MCPClientWithAuthRetry: Raises: MCPAuthError: If authentication fails or max retries reached """ - from services.tools.mcp_oauth_service import MCPOAuthService - - if not self.provider_entity or not self.auth_callback: + if not self.provider_entity or not self.auth_callback or not self.mcp_service: raise error - if self._has_retried: raise error @@ -94,14 +94,12 @@ class MCPClientWithAuthRetry: try: # Perform authentication - self.auth_callback(self.provider_entity, self.authorization_code) + self.auth_callback(self.provider_entity, self.mcp_service, self.authorization_code) # Retrieve new tokens - with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - self.provider_entity = oauth_service.get_provider_entity( - self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id - ) + self.provider_entity = self.mcp_service.get_provider_entity( + self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id + ) token = self.provider_entity.retrieve_tokens() if not token: raise MCPAuthError("Authentication failed - no token received") diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index d29c931d5d..27274c859b 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -119,23 +119,39 @@ class MCPTool(Tool): tool_parameters = self._handle_none_parameter(tool_parameters) # Get provider entity to access tokens - from sqlalchemy.orm import Session + from typing import TYPE_CHECKING - from extensions.ext_database import db - from services.tools.mcp_oauth_service import MCPOAuthService + if TYPE_CHECKING: + pass + + # Get MCP service from invoke parameters or create new one + provider_entity = None + mcp_service = None + + # Check if mcp_service is passed in tool_parameters + if "_mcp_service" in tool_parameters: + mcp_service = tool_parameters.pop("_mcp_service") + else: + # Fallback to creating service with database session + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from services.tools.mcp_tools_manage_service import MCPToolManageService - try: with Session(db.engine) as session: - oauth_service = MCPOAuthService(session=session) - provider_entity = oauth_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) + mcp_service = MCPToolManageService(session=session) + + if mcp_service: + try: + provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) # Try to get existing token and add to headers tokens = provider_entity.retrieve_tokens() if tokens and tokens.access_token: headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - except Exception: - # If provider retrieval or token fails, continue without auth - pass + except Exception: + # If provider retrieval or token fails, continue without auth + pass # Use MCPClientWithAuthRetry to handle authentication automatically try: @@ -145,8 +161,9 @@ class MCPTool(Tool): timeout=self.timeout, sse_read_timeout=self.sse_read_timeout, provider_entity=provider_entity, - auth_callback=auth, + auth_callback=auth if mcp_service else None, by_server_id=True, + mcp_service=mcp_service, ) as mcp_client: return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) except MCPConnectionError as e: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index afd8c434fd..a8f7267d35 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -775,7 +775,7 @@ class ToolManager: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) try: - provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id) + provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id) except ValueError: raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") @@ -918,7 +918,9 @@ class ToolManager: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) try: - mcp_provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id) + mcp_provider = mcp_service.get_provider_entity( + provider_id=provider_id, tenant_id=tenant_id, by_server_id=True + ) return mcp_provider.provider_icon except ValueError: raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") diff --git a/api/services/tools/mcp_oauth_service.py b/api/services/tools/mcp_oauth_service.py deleted file mode 100644 index 5f9904b110..0000000000 --- a/api/services/tools/mcp_oauth_service.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -MCP OAuth Service - handles OAuth-related database operations -""" - -from sqlalchemy.orm import Session - -from core.entities.mcp_provider import MCPProviderEntity -from core.mcp.types import OAuthClientInformationFull, OAuthTokens -from services.tools.mcp_tools_manage_service import MCPToolManageService - - -class MCPOAuthService: - """Service for handling MCP OAuth operations""" - - def __init__(self, session: Session): - self._session = session - self._mcp_service = MCPToolManageService(session=session) - - def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity: - """Get provider entity by ID""" - if by_server_id: - db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id) - else: - db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) - return db_provider.to_entity() - - def save_client_information( - self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull - ) -> None: - """Save OAuth client information""" - db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) - self._mcp_service.update_provider_credentials( - provider=db_provider, - credentials={"client_information": client_information.model_dump()}, - ) - - def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None: - """Save OAuth tokens""" - db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) - token_dict = tokens.model_dump() - self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed) - - def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None: - """Save PKCE code verifier""" - db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) - self._mcp_service.update_provider_credentials( - provider=db_provider, credentials={"code_verifier": code_verifier} - ) - - def clear_credentials(self, provider_id: str, tenant_id: str) -> None: - """Clear provider credentials""" - db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id) - self._mcp_service.clear_provider_credentials(provider=db_provider) diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index ff2ffcd22b..2a9da94c52 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -14,7 +14,6 @@ from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError -from core.mcp.types import OAuthTokens from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.utils.encryption import ProviderConfigEncrypter from models.tools import MCPToolProvider @@ -26,85 +25,51 @@ UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" class MCPToolManageService: - """ - Service class for managing mcp tools. - """ + """Service class for managing MCP tools and providers.""" def __init__(self, session: Session): self._session = session - def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]: + # ========== Provider CRUD Operations ========== + + def get_provider( + self, *, provider_id: Optional[str] = None, server_identifier: Optional[str] = None, tenant_id: str + ) -> MCPToolProvider: """ - Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT. + Get MCP provider by ID or server identifier. Args: - headers: Dictionary of headers to encrypt - tenant_id: Tenant ID for encryption + provider_id: Provider ID (UUID) + server_identifier: Server identifier + tenant_id: Tenant ID Returns: - Dictionary with all headers encrypted + MCPToolProvider instance + + Raises: + ValueError: If provider not found """ - if not headers: - return {} + if server_identifier: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier + ) + else: + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id + ) - from core.entities.provider_entities import BasicProviderConfig - from core.helper.provider_cache import NoOpProviderCredentialCache - from core.tools.utils.encryption import create_provider_encrypter - - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers] - - encrypter_instance, _ = create_provider_encrypter( - tenant_id=tenant_id, - config=config, - cache=NoOpProviderCredentialCache(), - ) - - return encrypter_instance.encrypt(headers) - - def _retrieve_remote_mcp_tools( - self, - server_url: str, - headers: dict[str, str], - provider_entity: MCPProviderEntity, - auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]], - ): - """Retrieve tools from remote MCP server""" - with MCPClientWithAuthRetry( - server_url, - headers=headers, - timeout=provider_entity.timeout, - sse_read_timeout=provider_entity.sse_read_timeout, - provider_entity=provider_entity, - auth_callback=auth_callback, - ) as mcp_client: - tools = mcp_client.list_tools() - return tools - - def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]: - """Process headers and add OAuth token if available""" - headers = headers.copy() if headers else {} - if tokens: - headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" - return headers - - def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider: - """Get MCP provider by ID""" - stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) provider = self._session.scalar(stmt) if not provider: raise ValueError("MCP tool not found") return provider - def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider: - """Get MCP provider by server identifier""" - stmt = select(MCPToolProvider).where( - MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier - ) - provider = self._session.scalar(stmt) - if not provider: - raise ValueError("MCP tool not found") - return provider + def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity: + """Get provider entity by ID or server identifier.""" + if by_server_id: + db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id) + else: + db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + return db_provider.to_entity() def create_provider( self, @@ -121,36 +86,15 @@ class MCPToolManageService: sse_read_timeout: float, headers: dict[str, str] | None = None, ) -> ToolProviderApiEntity: - """Create a new MCP provider""" + """Create a new MCP provider.""" server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() # Check for existing provider - stmt = select(MCPToolProvider).where( - MCPToolProvider.tenant_id == tenant_id, - or_( - MCPToolProvider.name == name, - MCPToolProvider.server_url_hash == server_url_hash, - MCPToolProvider.server_identifier == server_identifier, - ), - ) - existing_provider = self._session.scalar(stmt) + self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier) - if existing_provider: - if existing_provider.name == name: - raise ValueError(f"MCP tool {name} already exists") - if existing_provider.server_url_hash == server_url_hash: - raise ValueError(f"MCP tool {server_url} already exists") - if existing_provider.server_identifier == server_identifier: - raise ValueError(f"MCP tool {server_identifier} already exists") - - # Encrypt server URL + # Encrypt sensitive data encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - - # Encrypt headers - encrypted_headers = None - if headers: - encrypted_headers_dict = self._encrypt_headers(headers, tenant_id) - encrypted_headers = json.dumps(encrypted_headers_dict) + encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None # Create provider mcp_tool = MCPToolProvider( @@ -161,7 +105,7 @@ class MCPToolManageService: user_id=user_id, authed=False, tools="[]", - icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, + icon=self._prepare_icon(icon, icon_type, icon_background), server_identifier=server_identifier, timeout=timeout, sse_read_timeout=sse_read_timeout, @@ -173,59 +117,6 @@ class MCPToolManageService: return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) - def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: - """List all MCP providers for a tenant""" - stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name) - - mcp_providers = self._session.scalars(stmt).all() - - return [ - ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list) - for provider in mcp_providers - ] - - def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: - """List tools from remote MCP server""" - from core.mcp.auth.auth_flow import auth - - # Load provider and convert to entity - db_provider = self.get_provider_by_id(provider_id, tenant_id) - provider_entity = db_provider.to_entity() - - # Handle authentication headers if authed - if not provider_entity.authed: - raise ValueError("Please auth the tool first") - - tokens = provider_entity.retrieve_tokens() - headers = self._process_headers(provider_entity.headers, tokens) - server_url = provider_entity.decrypt_server_url() - try: - tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth) - except MCPError as e: - raise ValueError(f"Failed to connect to MCP server: {e}") - - # Update database record with new tools - db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) - db_provider.authed = True - db_provider.updated_at = datetime.now() - self._session.commit() - - # Create API response using entity - user = db_provider.load_user() - response = provider_entity.to_api_response( - user_name=user.name if user else None, - ) - response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools) - response["plugin_unique_identifier"] = provider_entity.provider_id - - return ToolProviderApiEntity(**response) - - def delete_provider(self, *, tenant_id: str, provider_id: str) -> None: - """Delete an MCP provider""" - mcp_tool = self.get_provider_by_id(provider_id, tenant_id) - self._session.delete(mcp_tool) - self._session.commit() - def update_provider( self, *, @@ -241,8 +132,8 @@ class MCPToolManageService: sse_read_timeout: float | None = None, headers: dict[str, str] | None = None, ) -> None: - """Update an MCP provider""" - mcp_provider = self.get_provider_by_id(provider_id, tenant_id) + """Update an MCP provider.""" + mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) reconnect_result = None encrypted_server_url = None @@ -263,9 +154,7 @@ class MCPToolManageService: # Update basic fields mcp_provider.updated_at = datetime.now() mcp_provider.name = name - mcp_provider.icon = ( - json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon - ) + mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background) mcp_provider.server_identifier = server_identifier # Update server URL if changed @@ -284,35 +173,86 @@ class MCPToolManageService: if sse_read_timeout is not None: mcp_provider.sse_read_timeout = sse_read_timeout if headers is not None: - # Encrypt headers - if headers: - encrypted_headers_dict = self._encrypt_headers(headers, tenant_id) - mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict) - else: - mcp_provider.encrypted_headers = None + mcp_provider.encrypted_headers = ( + self._prepare_encrypted_headers(headers, tenant_id) if headers else None + ) self._session.commit() except IntegrityError as e: self._session.rollback() - error_msg = str(e.orig) - if "unique_mcp_provider_name" in error_msg: - raise ValueError(f"MCP tool {name} already exists") - if "unique_mcp_provider_server_url" in error_msg: - raise ValueError(f"MCP tool {server_url} already exists") - if "unique_mcp_provider_server_identifier" in error_msg: - raise ValueError(f"MCP tool {server_identifier} already exists") - raise + self._handle_integrity_error(e, name, server_url, server_identifier) except Exception: self._session.rollback() raise + def delete_provider(self, *, tenant_id: str, provider_id: str) -> None: + """Delete an MCP provider.""" + mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + self._session.delete(mcp_tool) + self._session.commit() + + def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: + """List all MCP providers for a tenant.""" + stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name) + mcp_providers = self._session.scalars(stmt).all() + + return [ + ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list) + for provider in mcp_providers + ] + + # ========== Tool Operations ========== + + def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: + """List tools from remote MCP server.""" + from core.mcp.auth.auth_flow import auth + + # Load provider and convert to entity + db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + provider_entity = db_provider.to_entity() + + # Verify authentication + if not provider_entity.authed: + raise ValueError("Please auth the tool first") + + # Prepare headers with auth token + headers = self._prepare_auth_headers(provider_entity) + + # Retrieve tools from remote server + server_url = provider_entity.decrypt_server_url() + try: + tools = self._retrieve_remote_mcp_tools( + server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c) + ) + except MCPError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") + + # Update database with retrieved tools + db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + db_provider.authed = True + db_provider.updated_at = datetime.now() + self._session.commit() + + # Build API response + return self._build_tool_provider_response(db_provider, provider_entity, tools) + + # ========== OAuth and Credentials Operations ========== + def update_provider_credentials( - self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False + self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool | None = None ) -> None: - """Update provider credentials""" + """ + Update provider credentials with encryption. + + Args: + provider: Provider instance + credentials: Credentials to save + authed: Whether provider is authenticated (None means keep current state) + """ from core.tools.mcp_tool.provider import MCPToolProviderController + # Encrypt new credentials provider_controller = MCPToolProviderController.from_db(provider) tool_configuration = ProviderConfigEncrypter( tenant_id=provider.tenant_id, @@ -320,24 +260,130 @@ class MCPToolManageService: provider_config_cache=NoOpProviderCredentialCache(), ) encrypted_credentials = tool_configuration.encrypt(credentials) + + # Update provider provider.updated_at = datetime.now() provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials}) - provider.authed = authed - if not authed: - provider.tools = "[]" + + if authed is not None: + provider.authed = authed + if not authed: + provider.tools = "[]" self._session.commit() + def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None: + """ + Save OAuth-related data (tokens, client info, code verifier). + + Args: + provider_id: Provider ID + tenant_id: Tenant ID + data: Data to save (tokens, client info, or code verifier) + data_type: Type of data ('tokens', 'client_info', 'code_verifier', 'mixed') + """ + db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) + + credentials = {} + authed = None + + if data_type == "tokens" or (data_type == "mixed" and "access_token" in data): + # OAuth tokens + credentials = data + authed = True + elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data): + # OAuth client information + credentials = data + elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data): + # PKCE code verifier + credentials = data + else: + credentials = data + + self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed) + def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None: - """Clear provider credentials""" + """Clear all credentials for a provider.""" provider.tools = "[]" provider.encrypted_credentials = "{}" provider.updated_at = datetime.now() provider.authed = False self._session.commit() + # ========== Private Helper Methods ========== + + def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None: + """Check if provider with same attributes already exists.""" + stmt = select(MCPToolProvider).where( + MCPToolProvider.tenant_id == tenant_id, + or_( + MCPToolProvider.name == name, + MCPToolProvider.server_url_hash == server_url_hash, + MCPToolProvider.server_identifier == server_identifier, + ), + ) + existing_provider = self._session.scalar(stmt) + + if existing_provider: + if existing_provider.name == name: + raise ValueError(f"MCP tool {name} already exists") + if existing_provider.server_url_hash == server_url_hash: + raise ValueError("MCP tool with this server URL already exists") + if existing_provider.server_identifier == server_identifier: + raise ValueError(f"MCP tool {server_identifier} already exists") + + def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str: + """Prepare icon data for storage.""" + if icon_type == "emoji": + return json.dumps({"content": icon, "background": icon_background}) + return icon + + def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str: + """Encrypt headers and prepare for storage.""" + from core.entities.provider_entities import BasicProviderConfig + from core.tools.utils.encryption import create_provider_encrypter + + # Create dynamic config for all headers as SECRET_INPUT + config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers] + + encrypter_instance, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=config, + cache=NoOpProviderCredentialCache(), + ) + + encrypted_headers_dict = encrypter_instance.encrypt(headers) + return json.dumps(encrypted_headers_dict) + + def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]: + """Prepare headers with OAuth token if available.""" + headers = provider_entity.headers.copy() if provider_entity.headers else {} + tokens = provider_entity.retrieve_tokens() + if tokens: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + return headers + + def _retrieve_remote_mcp_tools( + self, + server_url: str, + headers: dict[str, str], + provider_entity: MCPProviderEntity, + auth_callback: Callable[[MCPProviderEntity, "MCPToolManageService", Optional[str]], dict[str, str]], + ): + """Retrieve tools from remote MCP server.""" + with MCPClientWithAuthRetry( + server_url, + headers=headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + provider_entity=provider_entity, + auth_callback=auth_callback, + mcp_service=self, + ) as mcp_client: + return mcp_client.list_tools() + def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]: - """Attempt to reconnect to MCP provider with new server URL""" + """Attempt to reconnect to MCP provider with new server URL.""" from core.mcp.auth.auth_flow import auth provider_entity = provider.to_entity() @@ -352,7 +398,8 @@ class MCPToolManageService: timeout=timeout, sse_read_timeout=sse_read_timeout, provider_entity=provider_entity, - auth_callback=auth, + auth_callback=lambda p, s, c: auth(p, self, c), + mcp_service=self, ) as mcp_client: tools = mcp_client.list_tools() return { @@ -364,3 +411,28 @@ class MCPToolManageService: return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"} except MCPError as e: raise ValueError(f"Failed to re-connect MCP server: {e}") from e + + def _build_tool_provider_response( + self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list + ) -> ToolProviderApiEntity: + """Build API response for tool provider.""" + user = db_provider.load_user() + response = provider_entity.to_api_response( + user_name=user.name if user else None, + ) + response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools) + response["plugin_unique_identifier"] = provider_entity.provider_id + return ToolProviderApiEntity(**response) + + def _handle_integrity_error( + self, error: IntegrityError, name: str, server_url: str, server_identifier: str + ) -> None: + """Handle database integrity errors with user-friendly messages.""" + error_msg = str(error.orig) + if "unique_mcp_provider_name" in error_msg: + raise ValueError(f"MCP tool {name} already exists") + if "unique_mcp_provider_server_url" in error_msg: + raise ValueError(f"MCP tool {server_url} already exists") + if "unique_mcp_provider_server_identifier" in error_msg: + raise ValueError(f"MCP tool {server_identifier} already exists") + raise