refactor(mcp): clean the client code

This commit is contained in:
Novice 2025-09-12 15:23:53 +08:00
parent aa44c38b58
commit f16151ea29
5 changed files with 160 additions and 131 deletions

View File

@ -947,15 +947,19 @@ class ToolMCPAuthApi(Resource):
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider: if not provider:
raise ValueError("provider not found") raise ValueError("provider not found")
# headers1: if headers is provided, use it and don't need to get token
headers = provider.decrypted_headers or {}
# headers2: Add OAuth token if authed and no headers provided
if not provider.decrypted_headers and provider.authed:
token = OAuthClientProvider(provider_id, tenant_id, for_list=True).tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
try: try:
# try to connect to MCP server with headers
with MCPClient( with MCPClient(
provider.decrypted_server_url, provider.decrypted_server_url,
provider_id, headers=headers,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout, timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout, sse_read_timeout=provider.sse_read_timeout,
): ):
@ -966,8 +970,11 @@ class ToolMCPAuthApi(Resource):
) )
return {"result": "success"} return {"result": "success"}
except MCPAuthError: except MCPAuthError as e:
try: try:
if provider.decrypted_headers:
raise ValueError(f"Failed to authenticate, please check your headers: {e}") from e
# if auth failed, try to auth with OAuth or exchange token
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"]) return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
except Exception as e: except Exception as e:

View File

@ -299,7 +299,6 @@ def auth(
server_url: str, server_url: str,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
state_param: Optional[str] = None, state_param: Optional[str] = None,
for_list: bool = False,
) -> dict[str, str]: ) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage.""" """Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url) metadata = discover_oauth_metadata(server_url)

View File

@ -2,12 +2,12 @@ import logging
from collections.abc import Callable from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack from contextlib import AbstractContextManager, ExitStack
from types import TracebackType from types import TracebackType
from typing import Any, Optional from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession from core.mcp.session.client_session import ClientSession
from core.mcp.types import CallToolResult, Tool from core.mcp.types import CallToolResult, Tool
@ -18,40 +18,18 @@ class MCPClient:
def __init__( def __init__(
self, self,
server_url: str, server_url: str,
provider_id: str, headers: dict[str, str] | None = None,
tenant_id: str, timeout: float | None = None,
authed: bool = True, sse_read_timeout: float | None = None,
authorization_code: Optional[str] = None,
for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
): ):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url self.server_url = server_url
self.headers = headers or {} self.headers = headers or {}
self.timeout = timeout self.timeout = timeout
self.sse_read_timeout = sse_read_timeout self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects # Initialize session and client objects
self._session: Optional[ClientSession] = None self._session: ClientSession | None = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self._exit_stack = ExitStack() self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False self._initialized = False
def __enter__(self): def __enter__(self):
@ -59,9 +37,7 @@ class MCPClient:
self._initialized = True self._initialized = True
return self return self
def __exit__( def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
):
self.cleanup() self.cleanup()
def _initialize( def _initialize(
@ -87,61 +63,42 @@ class MCPClient:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp") self.connect_server(streamablehttp_client, "mcp")
def connect_server( def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True """
): Connect to the MCP server using streamable http or sse.
from core.mcp.auth.auth_flow import auth Default to streamable http.
Args:
client_factory: The client factory to use(streamablehttp_client or sse_client).
method_name: The method name to use(mcp or sse).
"""
streams_context = client_factory(
url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
try: # Use exit_stack to manage context managers properly
headers = ( if method_name == "mcp":
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
if self.authed and self.token streams = (read_stream, write_stream)
else self.headers else: # sse_client
) streams = self._exit_stack.enter_context(streams_context)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly session_context = ClientSession(*streams)
if method_name == "mcp": self._session = self._exit_stack.enter_context(session_context)
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context) self._session.initialize()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
try:
auth(self.provider, self.server_url, self.authorization_code)
except Exception as e:
raise ValueError(f"Failed to authenticate: {e}")
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
def list_tools(self) -> list[Tool]: def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport""" """List available tools from the MCP server"""
# List available tools to verify connection if not self._session:
if not self._initialized or not self._session:
raise ValueError("Session not initialized.") raise ValueError("Session not initialized.")
response = self._session.list_tools() response = self._session.list_tools()
tools = response.tools return response.tools
return tools
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""Call a tool""" """Call a tool"""
if not self._initialized or not self._session: if not self._session:
raise ValueError("Session not initialized.") raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args) return self._session.call_tool(tool_name, tool_args)
@ -155,6 +112,4 @@ class MCPClient:
raise ValueError(f"Error during cleanup: {e}") raise ValueError(f"Error during cleanup: {e}")
finally: finally:
self._session = None self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False self._initialized = False

View File

@ -3,12 +3,14 @@ import json
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
from core.mcp.auth.auth_flow import auth
from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
class MCPTool(Tool): class MCPTool(Tool):
@ -44,26 +46,7 @@ class MCPTool(Tool):
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
from core.tools.errors import ToolInvokeError result = self.invoke_remote_mcp_tool(tool_parameters)
try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ToolInvokeError("Please auth the tool first") from e
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
# handle dify tool output # handle dify tool output
for content in result.content: for content in result.content:
if isinstance(content, TextContent): if isinstance(content, TextContent):
@ -95,7 +78,7 @@ class MCPTool(Tool):
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]: def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
"""Process a list of JSON items.""" """Process a list of JSON items."""
if any(not isinstance(item, dict[str, Any]) for item in json_list): if any(not isinstance(item, dict) for item in json_list):
# If the list contains any non-dict item, treat the entire list as a text message. # If the list contains any non-dict item, treat the entire list as a text message.
yield self.create_text_message(str(json_list)) yield self.create_text_message(str(json_list))
return return
@ -130,3 +113,65 @@ class MCPTool(Tool):
for key, value in parameter.items() for key, value in parameter.items()
if value is not None and not (isinstance(value, str) and value.strip() == "") if value is not None and not (isinstance(value, str) and value.strip() == "")
} }
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
# Initialize auth provider
from core.mcp.auth.auth_provider import OAuthClientProvider
provider = None
try:
provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=False)
except Exception as e:
# If provider initialization fails, continue without auth
pass
# Try to get existing token and add to headers
if provider:
try:
token = provider.tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
except Exception:
# If token retrieval fails, continue without auth header
pass
# Define a helper function to invoke the tool
def _invoke_with_client(client_headers: dict[str, str]) -> CallToolResult:
with MCPClient(
self.server_url,
headers=client_headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
try:
# First attempt with current headers
return _invoke_with_client(headers)
except MCPAuthError as e:
# Authentication required - try to authenticate
if not provider:
raise ToolInvokeError("Authentication required but no auth provider available") from e
try:
# Perform authentication flow
auth(provider, self.server_url, None, None, False)
token = provider.tokens()
if not token:
raise ToolInvokeError("Authentication failed - no token received")
# Update headers with new token while preserving other headers
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Retry with authenticated headers
return _invoke_with_client(headers)
except MCPAuthError as auth_error:
raise ToolInvokeError("Authentication failed") from auth_error
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@ -1,7 +1,7 @@
import hashlib import hashlib
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -10,10 +10,10 @@ from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPError from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.mcp.types import OAuthTokens
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import ProviderConfigEncrypter from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import MCPToolProvider from models.tools import MCPToolProvider
@ -55,7 +55,25 @@ class MCPToolManageService:
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )
return cast(dict[str, str], encrypter_instance.encrypt(headers)) return encrypter_instance.encrypt(headers)
@staticmethod
def _retrieve_remote_mcp_tools(server_url: str, headers: dict[str, str], timeout: float, sse_read_timeout: float):
with MCPClient(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return tools
@staticmethod
def _process_headers(headers: dict[str, str], tokens: OAuthTokens | None = None):
headers = headers or {}
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
@staticmethod @staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
@ -153,6 +171,9 @@ class MCPToolManageService:
@classmethod @classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
from core.mcp.auth.auth_flow import auth
from core.mcp.auth.auth_provider import OAuthClientProvider
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed authed = mcp_provider.authed
@ -160,20 +181,24 @@ class MCPToolManageService:
timeout = mcp_provider.timeout timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout sse_read_timeout = mcp_provider.sse_read_timeout
try: # Handle authentication headers if authed
with MCPClient( if not authed:
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first") raise ValueError("Please auth the tool first")
provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
tokens = provider.tokens()
headers = cls._process_headers(headers, tokens)
try:
tools = cls._retrieve_remote_mcp_tools(server_url, headers, timeout, sse_read_timeout)
except MCPAuthError:
try:
auth(provider, server_url, None, None, False)
tokens = provider.tokens()
re_authed_headers = cls._process_headers(headers, tokens)
tools = cls._retrieve_remote_mcp_tools(server_url, re_authed_headers, timeout, sse_read_timeout)
except Exception:
raise ValueError("Please auth the tool first")
except MCPError as e: except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}") raise ValueError(f"Failed to connect to MCP server: {e}")
@ -284,10 +309,12 @@ class MCPToolManageService:
def update_mcp_provider_credentials( def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
): ):
from core.tools.mcp_tool.provider import MCPToolProviderController
provider_controller = MCPToolProviderController.from_db(mcp_provider) provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id, tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(), provider_config_cache=NoOpProviderCredentialCache(),
) )
credentials = tool_configuration.encrypt(credentials) credentials = tool_configuration.encrypt(credentials)
@ -310,7 +337,7 @@ class MCPToolManageService:
db.session.commit() db.session.commit()
@classmethod @classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str): def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str) -> dict[str, Any]:
# Get the existing provider to access headers and timeout settings # Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers headers = mcp_provider.decrypted_headers
@ -320,10 +347,6 @@ class MCPToolManageService:
try: try:
with MCPClient( with MCPClient(
server_url, server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
headers=headers, headers=headers,
timeout=timeout, timeout=timeout,
sse_read_timeout=sse_read_timeout, sse_read_timeout=sse_read_timeout,