chore: fix review issues

This commit is contained in:
Novice 2025-10-14 20:36:13 +08:00
parent d5a7a537e5
commit 5c6a2af448
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
11 changed files with 296 additions and 257 deletions

View File

@ -16,7 +16,7 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
@ -44,7 +44,9 @@ def is_valid_url(url: str) -> bool:
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
except (ValueError, TypeError):
# ValueError: Invalid URL format
# TypeError: url is not a string
return False
@ -886,7 +888,7 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@ -897,14 +899,10 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
headers=args["headers"],
client_id=authentication.client_id if authentication else None,
client_secret=authentication.client_secret if authentication else None,
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
configuration=configuration,
authentication=authentication,
)
session.commit()
return jsonable_encoder(result)
@setup_required
@ -932,7 +930,7 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
@ -943,14 +941,10 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
headers=args["headers"],
client_id=authentication.client_id if authentication else None,
client_secret=authentication.client_secret if authentication else None,
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
configuration=configuration,
authentication=authentication,
)
session.commit()
return {"result": "success"}
@setup_required
@ -962,10 +956,9 @@ class ToolProviderMCPApi(Resource):
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id.current_tenant_id, provider_id=args["provider_id"])
session.commit()
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@ -983,23 +976,18 @@ class ToolMCPAuthApi(Resource):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
with session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Option 1: if headers is provided, use it and don't need to get token
headers = provider_entity.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not provider_entity.headers and provider_entity.authed:
token = provider_entity.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
@ -1008,18 +996,20 @@ class ToolMCPAuthApi(Resource):
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
session.commit()
# Create new transaction for update
with session.begin():
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
service = MCPToolManageService(session=session)
return auth(provider_entity, service, args.get("authorization_code"))
except MCPError as e:
service.clear_provider_credentials(provider=db_provider)
session.commit()
with session.begin():
service.clear_provider_credentials(provider=db_provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@ -1044,7 +1034,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
@ -1058,7 +1048,7 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
@ -1078,9 +1068,8 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"]
# Create service instance for handle_callback
with Session(db.engine) as session:
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
handle_callback(state_key, authorization_code, mcp_service)
session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel, Field
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
@ -20,7 +20,6 @@ if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
DEFAULT_GRANT_TYPE = "authorization_code"
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
@ -34,12 +33,12 @@ class MCPSupportGrantType(StrEnum):
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
REFRESH_TOKEN = "refresh_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
grant_type: MCPSupportGrantType = Field(default=MCPSupportGrantType.AUTHORIZATION_CODE)
class MCPConfiguration(BaseModel):
@ -110,7 +109,7 @@ class MCPProviderEntity(BaseModel):
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", DEFAULT_GRANT_TYPE)
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
@ -118,12 +117,12 @@ class MCPProviderEntity(BaseModel):
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = "client_credentials"
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
elif "authorization_code" in client_info["grant_types"]:
grant_type = "authorization_code"
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
# Configure based on grant type
is_client_credentials = grant_type == "client_credentials"
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
@ -212,10 +211,7 @@ class MCPProviderEntity(BaseModel):
json_fields = ["redirect_uris", "grant_types", "response_types"]
for field in json_fields:
if field in credentials:
try:
client_info[field] = json.loads(credentials[field])
except:
client_info[field] = []
client_info[field] = json.loads(credentials[field])
if "scope" in credentials:
client_info["scope"] = credentials["scope"]
@ -237,10 +233,10 @@ class MCPProviderEntity(BaseModel):
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/******"
return base_url
masked = parsed._replace(path="/******")
return masked.geturl()
return parsed.geturl()
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
@ -289,46 +285,41 @@ class MCPProviderEntity(BaseModel):
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
try:
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
except Exception:
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
@ -336,3 +327,15 @@ class MCPProviderEntity(BaseModel):
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
return self._decrypt_dict(self.credentials)
def decrypt_authentication(self) -> dict[str, Any]:
"""Decrypt authentication"""
# Option 1: if headers is provided, use it and don't need to get token
headers = self.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not self.headers and self.authed:
token = self.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
return headers

View File

@ -7,10 +7,11 @@ import urllib.parse
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse
import httpx
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import BaseModel, ValidationError
from core.entities.mcp_provider import MCPProviderEntity
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
@ -106,15 +107,15 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
url_for_resource_discovery += f"#{b_fragment}"
try:
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
response = httpx.get(url_for_resource_discovery, headers=headers)
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
# Support both singular and plural forms
@ -125,7 +126,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except httpx.RequestError:
except RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
@ -138,8 +139,8 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
@ -148,15 +149,15 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
for url in urls_to_try:
try:
response = httpx.get(url, headers=headers)
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
continue # Try next URL
continue
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
@ -232,7 +233,7 @@ def exchange_authorization(
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
if metadata:
token_url = metadata.token_endpoint
@ -252,7 +253,7 @@ def exchange_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -265,7 +266,7 @@ def refresh_authorization(
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
if metadata:
token_url = metadata.token_endpoint
@ -283,7 +284,7 @@ def refresh_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -296,7 +297,7 @@ def client_credentials_flow(
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = "client_credentials"
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
if metadata:
token_url = metadata.token_endpoint
@ -323,7 +324,7 @@ def client_credentials_flow(
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = httpx.post(token_url, headers=headers, data=data)
response = ssrf_proxy.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
@ -345,7 +346,7 @@ def register_client(
else:
registration_url = urljoin(server_url, "/register")
response = httpx.post(
response = ssrf_proxy.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
@ -360,7 +361,6 @@ def auth(
mcp_service: "MCPToolManageService",
authorization_code: str | None = None,
state_param: str | None = None,
grant_type: str = "authorization_code",
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
server_url = provider.decrypt_server_url()
@ -371,25 +371,37 @@ def auth(
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Check if we should use client credentials flow
credentials = provider.decrypt_credentials()
stored_grant_type = credentials.get("grant_type", "authorization_code")
# Determine grant type based on server metadata
if not server_metadata:
raise ValueError("Failed to discover OAuth metadata from server")
# Use stored grant type if available, otherwise use parameter
effective_grant_type = stored_grant_type or grant_type
supported_grant_types = server_metadata.grant_types_supported or []
# Convert to lowercase for comparison
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
# Determine which grant type to use
effective_grant_type = None
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == "client_credentials":
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, server_metadata, client_metadata)
except httpx.RequestError as e:
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
# Save client information using service layer
@ -400,7 +412,7 @@ def auth(
client_information = full_information
# Handle client credentials flow
if effective_grant_type == "client_credentials":
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
@ -413,11 +425,14 @@ def auth(
# Save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = "client_credentials"
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
return {"result": "success"}
except Exception as e:
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
@ -465,7 +480,10 @@ def auth(
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
return {"result": "success"}
except Exception as e:
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow (only for authorization code flow)

View File

@ -99,7 +99,11 @@ class MCPClientWithAuthRetry(MCPClient):
# Clear authorization code after first use
self.authorization_code = None
except MCPAuthError:
# Re-raise MCPAuthError as is
raise
except Exception as e:
# Catch all exceptions during auth retry
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
@ -47,9 +48,9 @@ class ToolProviderApiEntity(BaseModel):
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
authentication: dict[str, str] | None = Field(default=None, description="The OAuth config of the MCP tool")
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: dict[str, str] | None = Field(
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@ -74,8 +75,14 @@ class ToolProviderApiEntity(BaseModel):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("configuration", self.configuration))
optional_fields.update(self.optional_field("authentication", self.authentication))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))

View File

@ -61,10 +61,7 @@ class MCPToolProviderController(ToolProviderController):
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
try:
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
except Exception:
remote_mcp_tools = []
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
tools = [
ToolEntity(
@ -87,7 +84,7 @@ class MCPToolProviderController(ToolProviderController):
)
for remote_mcp_tool in remote_mcp_tools
]
if not db_provider.icon:
if not entity.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(

View File

@ -60,11 +60,18 @@ class MCPTool(Tool):
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
"""Process text content and yield appropriate messages."""
try:
content_json = json.loads(content.text)
yield from self._process_json_content(content_json)
except json.JSONDecodeError:
yield self.create_text_message(content.text)
# Check if content looks like JSON before attempting to parse
text = content.text.strip()
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
try:
content_json = json.loads(text)
yield from self._process_json_content(content_json)
return
except json.JSONDecodeError:
pass
# If not JSON or parsing failed, treat as plain text
yield self.create_text_message(content.text)
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
@ -119,10 +126,6 @@ class MCPTool(Tool):
tool_parameters = self._handle_none_parameter(tool_parameters)
# Get provider entity to access tokens
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
# Get MCP service from invoke parameters or create new one
provider_entity = None
@ -131,18 +134,7 @@ class MCPTool(Tool):
# Check if mcp_service is passed in tool_parameters
if "_mcp_service" in tool_parameters:
mcp_service = tool_parameters.pop("_mcp_service")
else:
# Fallback to creating service with database session
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
if mcp_service:
try:
if mcp_service:
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
@ -150,23 +142,54 @@ class MCPTool(Tool):
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
except Exception:
# If provider retrieval or token fails, continue without auth
pass
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except (ValueError, TypeError, KeyError) as e:
# Catch specific exceptions that might occur during tool invocation
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
else:
# Fallback to creating service with database session
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Use MCPClientWithAuthRetry to handle authentication automatically
try:
with MCPClientWithAuthRetry(
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@ -9,8 +9,7 @@ from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.mcp_provider import MCPProviderEntity
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth_client import MCPClientWithAuthRetry
@ -24,7 +23,6 @@ logger = logging.getLogger(__name__)
# Constants
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
DEFAULT_GRANT_TYPE = "authorization_code"
CLIENT_NAME = "Dify"
EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
@ -88,12 +86,9 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float,
sse_read_timeout: float,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
headers: dict[str, str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str = DEFAULT_GRANT_TYPE,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
@ -104,9 +99,11 @@ class MCPToolManageService:
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
if client_id and client_secret:
if authentication is not None:
# Build the full credentials structure with encrypted client_id and client_secret
encrypted_credentials = self._build_and_encrypt_credentials(client_id, client_secret, grant_type, tenant_id)
encrypted_credentials = self._build_and_encrypt_credentials(
authentication.client_id, authentication.client_secret, tenant_id
)
else:
encrypted_credentials = None
# Create provider
@ -120,16 +117,16 @@ class MCPToolManageService:
tools=EMPTY_TOOLS_JSON,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
encrypted_headers=encrypted_headers,
encrypted_credentials=encrypted_credentials,
)
self._session.add(mcp_tool)
self._session.commit()
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
self._session.flush()
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
def update_provider(
self,
@ -142,12 +139,9 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str | None = None,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
) -> None:
"""Update an MCP provider."""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
@ -185,10 +179,10 @@ class MCPToolManageService:
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
# Update optional fields
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if configuration.timeout is not None:
mcp_provider.timeout = configuration.timeout
if configuration.sse_read_timeout is not None:
mcp_provider.sse_read_timeout = configuration.sse_read_timeout
if headers is not None:
if headers:
# Build headers preserving unchanged masked values
@ -200,20 +194,18 @@ class MCPToolManageService:
mcp_provider.encrypted_headers = None
# Update credentials if provided
if client_id is not None and client_secret is not None:
if authentication is not None:
# Merge with existing credentials to handle masked values
(
final_client_id,
final_client_secret,
final_grant_type,
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, mcp_provider)
# Use default grant_type if none found
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
) = self._merge_credentials_with_masked(
authentication.client_id, authentication.client_secret, mcp_provider
)
# Build and encrypt new credentials
encrypted_credentials = self._build_and_encrypt_credentials(
final_client_id, final_client_secret, final_grant_type, tenant_id
final_client_id, final_client_secret, tenant_id
)
mcp_provider.encrypted_credentials = encrypted_credentials
@ -221,7 +213,11 @@ class MCPToolManageService:
except IntegrityError as e:
self._session.rollback()
self._handle_integrity_error(e, name, server_url, server_identifier)
except Exception:
except (ValueError, AttributeError, TypeError) as e:
# Catch specific exceptions that might occur during update
# ValueError: invalid data provided
# AttributeError: missing required attributes
# TypeError: type conversion errors
self._session.rollback()
raise
@ -271,7 +267,7 @@ class MCPToolManageService:
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.commit()
self._session.flush()
# Build API response
return self._build_tool_provider_response(db_provider, provider_entity, tools)
@ -309,7 +305,7 @@ class MCPToolManageService:
if not authed:
provider.tools = EMPTY_TOOLS_JSON
self._session.commit()
self._session.flush()
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
"""
@ -495,20 +491,21 @@ class MCPToolManageService:
def _merge_credentials_with_masked(
self,
client_id: str,
client_secret: str,
grant_type: str | None,
client_secret: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[str, str, str | None]:
) -> tuple[
str,
str | None,
]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
grant_type: Grant type from frontend
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret, grant_type)
Tuple of (final_client_id, final_client_secret)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
@ -526,35 +523,18 @@ class MCPToolManageService:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
return final_client_id, final_client_secret
return final_client_id, final_client_secret, final_grant_type
def _build_and_encrypt_credentials(
self, client_id: str, client_secret: str, grant_type: str, tenant_id: str
) -> str:
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"client_secret": client_secret,
"grant_type": grant_type,
"client_name": CLIENT_NAME,
"is_dynamic_registration": False,
}
# Add grant types and response types based on grant_type
if grant_type == "client_credentials":
credentials_data["grant_types"] = json.dumps(["client_credentials"])
credentials_data["response_types"] = json.dumps([])
credentials_data["redirect_uris"] = json.dumps([])
else:
credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"])
credentials_data["response_types"] = json.dumps(["code"])
credentials_data["redirect_uris"] = json.dumps(
[f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"]
)
# Only client_id and client_secret need encryption
secret_fields = ["client_id", "client_secret"]
secret_fields = ["client_id", "client_secret"] if client_secret else ["client_id"]
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)

View File

@ -6,6 +6,7 @@ from typing import Any, Union
from yarl import URL
from configs import dify_config
from core.entities.mcp_provider import MCPConfiguration
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
@ -246,6 +247,13 @@ class ToolTransformService:
)
response["server_identifier"] = db_provider.server_identifier
# Convert configuration dict to MCPConfiguration object
if "configuration" in response and isinstance(response["configuration"], dict):
response["configuration"] = MCPConfiguration(
timeout=float(response["configuration"]["timeout"]),
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
)
return ToolProviderApiEntity(**response)
@staticmethod

View File

@ -139,7 +139,7 @@ class TestRedisStateManagement:
class TestOAuthDiscovery:
"""Test OAuth discovery functions."""
@patch("httpx.get")
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_success(self, mock_get):
"""Test successful resource discovery check."""
mock_response = Mock()
@ -152,11 +152,11 @@ class TestOAuthDiscovery:
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource/endpoint",
"https://api.example.com/.well-known/oauth-protected-resource",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("httpx.get")
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_not_supported(self, mock_get):
"""Test resource discovery not supported."""
mock_response = Mock()
@ -168,7 +168,7 @@ class TestOAuthDiscovery:
assert supported is False
assert auth_url == ""
@patch("httpx.get")
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
"""Test resource discovery with query and fragment."""
mock_response = Mock()
@ -181,11 +181,11 @@ class TestOAuthDiscovery:
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource/path?query=1#fragment",
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("httpx.get")
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
@ -207,11 +207,11 @@ class TestOAuthDiscovery:
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/openid-configuration",
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("httpx.get")
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
@ -236,7 +236,7 @@ class TestOAuthDiscovery:
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("httpx.get")
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
"""Test OAuth metadata discovery when not found."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
@ -336,7 +336,7 @@ class TestAuthorizationFlow:
assert "does not support response type code" in str(exc_info.value)
@patch("httpx.post")
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_success(self, mock_post):
"""Test successful authorization code exchange."""
mock_response = Mock()
@ -384,7 +384,7 @@ class TestAuthorizationFlow:
},
)
@patch("httpx.post")
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_failure(self, mock_post):
"""Test failed authorization code exchange."""
mock_response = Mock()
@ -406,7 +406,7 @@ class TestAuthorizationFlow:
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
@patch("httpx.post")
@patch("core.helper.ssrf_proxy.post")
def test_refresh_authorization_success(self, mock_post):
"""Test successful token refresh."""
mock_response = Mock()
@ -442,7 +442,7 @@ class TestAuthorizationFlow:
},
)
@patch("httpx.post")
@patch("core.helper.ssrf_proxy.post")
def test_register_client_success(self, mock_post):
"""Test successful client registration."""
mock_response = Mock()
@ -576,7 +576,12 @@ class TestAuthOrchestration:
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = None
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
client_name="Dify",
@ -679,7 +684,12 @@ class TestAuthOrchestration:
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = None
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
result = auth(mock_provider, mock_service)

View File

@ -228,8 +228,7 @@ class TestMCPToolTransform:
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"timeout": 30,
"sse_read_timeout": 300,
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
"original_headers": {"Authorization": "Bearer secret-token"},
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
@ -246,8 +245,9 @@ class TestMCPToolTransform:
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
assert result.server_identifier == "server-identifier-456"
assert result.timeout == 30
assert result.sse_read_timeout == 300
assert result.configuration is not None
assert result.configuration.timeout == 30
assert result.configuration.sse_read_timeout == 300
assert result.original_headers == {"Authorization": "Bearer secret-token"}
assert len(result.tools) == 1
assert result.tools[0].description.en_US == "Tool description"