mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 01:38:19 +08:00
chore: fix review issues
This commit is contained in:
parent
d5a7a537e5
commit
5c6a2af448
@ -16,7 +16,7 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
@ -44,7 +44,9 @@ def is_valid_url(url: str) -> bool:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except Exception:
|
||||
except (ValueError, TypeError):
|
||||
# ValueError: Invalid URL format
|
||||
# TypeError: url is not a string
|
||||
return False
|
||||
|
||||
|
||||
@ -886,7 +888,7 @@ class ToolProviderMCPApi(Resource):
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
@ -897,14 +899,10 @@ class ToolProviderMCPApi(Resource):
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
headers=args["headers"],
|
||||
client_id=authentication.client_id if authentication else None,
|
||||
client_secret=authentication.client_secret if authentication else None,
|
||||
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
)
|
||||
session.commit()
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@setup_required
|
||||
@ -932,7 +930,7 @@ class ToolProviderMCPApi(Resource):
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -943,14 +941,10 @@ class ToolProviderMCPApi(Resource):
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
headers=args["headers"],
|
||||
client_id=authentication.client_id if authentication else None,
|
||||
client_secret=authentication.client_secret if authentication else None,
|
||||
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@ -962,10 +956,9 @@ class ToolProviderMCPApi(Resource):
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id.current_tenant_id, provider_id=args["provider_id"])
|
||||
session.commit()
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -983,23 +976,18 @@ class ToolMCPAuthApi(Resource):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
with session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_authentication()
|
||||
|
||||
# Option 1: if headers is provided, use it and don't need to get token
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
# Option 2: Add OAuth token if authed and no headers provided
|
||||
if not provider_entity.headers and provider_entity.authed:
|
||||
token = provider_entity.retrieve_tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
# Try to connect without active transaction
|
||||
try:
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClient(
|
||||
@ -1008,18 +996,20 @@ class ToolMCPAuthApi(Resource):
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
session.commit()
|
||||
# Create new transaction for update
|
||||
with session.begin():
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
return auth(provider_entity, service, args.get("authorization_code"))
|
||||
except MCPError as e:
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
session.commit()
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
@ -1044,7 +1034,7 @@ class ToolMCPListAllApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_providers(tenant_id=tenant_id)
|
||||
|
||||
@ -1058,7 +1048,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
tenant_id=tenant_id,
|
||||
@ -1078,9 +1068,8 @@ class ToolMCPCallbackApi(Resource):
|
||||
authorization_code = args["code"]
|
||||
|
||||
# Create service instance for handle_callback
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
handle_callback(state_key, authorization_code, mcp_service)
|
||||
session.commit()
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
@ -4,7 +4,7 @@ from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
@ -20,7 +20,6 @@ if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
# Constants
|
||||
DEFAULT_GRANT_TYPE = "authorization_code"
|
||||
CLIENT_NAME = "Dify"
|
||||
CLIENT_URI = "https://github.com/langgenius/dify"
|
||||
DEFAULT_TOKEN_TYPE = "Bearer"
|
||||
@ -34,12 +33,12 @@ class MCPSupportGrantType(StrEnum):
|
||||
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
class MCPAuthentication(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
grant_type: MCPSupportGrantType = Field(default=MCPSupportGrantType.AUTHORIZATION_CODE)
|
||||
|
||||
|
||||
class MCPConfiguration(BaseModel):
|
||||
@ -110,7 +109,7 @@ class MCPProviderEntity(BaseModel):
|
||||
credentials = self.decrypt_credentials()
|
||||
|
||||
# Try to get grant_type from different locations
|
||||
grant_type = credentials.get("grant_type", DEFAULT_GRANT_TYPE)
|
||||
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
|
||||
|
||||
# For nested structure, check if client_information has grant_types
|
||||
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||
@ -118,12 +117,12 @@ class MCPProviderEntity(BaseModel):
|
||||
# If grant_types is specified in client_information, use it to determine grant_type
|
||||
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
|
||||
if "client_credentials" in client_info["grant_types"]:
|
||||
grant_type = "client_credentials"
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||
elif "authorization_code" in client_info["grant_types"]:
|
||||
grant_type = "authorization_code"
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
|
||||
|
||||
# Configure based on grant type
|
||||
is_client_credentials = grant_type == "client_credentials"
|
||||
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||
|
||||
grant_types = ["refresh_token"]
|
||||
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
|
||||
@ -212,10 +211,7 @@ class MCPProviderEntity(BaseModel):
|
||||
json_fields = ["redirect_uris", "grant_types", "response_types"]
|
||||
for field in json_fields:
|
||||
if field in credentials:
|
||||
try:
|
||||
client_info[field] = json.loads(credentials[field])
|
||||
except:
|
||||
client_info[field] = []
|
||||
client_info[field] = json.loads(credentials[field])
|
||||
|
||||
if "scope" in credentials:
|
||||
client_info["scope"] = credentials["scope"]
|
||||
@ -237,10 +233,10 @@ class MCPProviderEntity(BaseModel):
|
||||
def masked_server_url(self) -> str:
|
||||
"""Masked server URL for display"""
|
||||
parsed = urlparse(self.decrypt_server_url())
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
if parsed.path and parsed.path != "/":
|
||||
return f"{base_url}/******"
|
||||
return base_url
|
||||
masked = parsed._replace(path="/******")
|
||||
return masked.geturl()
|
||||
return parsed.geturl()
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask a sensitive value for display"""
|
||||
@ -289,46 +285,41 @@ class MCPProviderEntity(BaseModel):
|
||||
|
||||
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Generic method to decrypt dictionary fields"""
|
||||
try:
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Only decrypt fields that are actually encrypted
|
||||
# For nested structures, client_information is not encrypted as a whole
|
||||
encrypted_fields = []
|
||||
for key, value in data.items():
|
||||
# Skip nested objects - they are not encrypted
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
# Only process string values that might be encrypted
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_fields.append(key)
|
||||
|
||||
if not encrypted_fields:
|
||||
return data
|
||||
|
||||
# Create dynamic config only for encrypted fields
|
||||
config = [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields
|
||||
]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Decrypt only the encrypted fields
|
||||
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
|
||||
|
||||
# Merge decrypted data with original data (preserving non-encrypted fields)
|
||||
result = data.copy()
|
||||
result.update(decrypted_data)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Only decrypt fields that are actually encrypted
|
||||
# For nested structures, client_information is not encrypted as a whole
|
||||
encrypted_fields = []
|
||||
for key, value in data.items():
|
||||
# Skip nested objects - they are not encrypted
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
# Only process string values that might be encrypted
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_fields.append(key)
|
||||
|
||||
if not encrypted_fields:
|
||||
return data
|
||||
|
||||
# Create dynamic config only for encrypted fields
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Decrypt only the encrypted fields
|
||||
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
|
||||
|
||||
# Merge decrypted data with original data (preserving non-encrypted fields)
|
||||
result = data.copy()
|
||||
result.update(decrypted_data)
|
||||
|
||||
return result
|
||||
|
||||
def decrypt_headers(self) -> dict[str, Any]:
|
||||
"""Decrypt headers"""
|
||||
return self._decrypt_dict(self.headers)
|
||||
@ -336,3 +327,15 @@ class MCPProviderEntity(BaseModel):
|
||||
def decrypt_credentials(self) -> dict[str, Any]:
|
||||
"""Decrypt credentials"""
|
||||
return self._decrypt_dict(self.credentials)
|
||||
|
||||
def decrypt_authentication(self) -> dict[str, Any]:
|
||||
"""Decrypt authentication"""
|
||||
# Option 1: if headers is provided, use it and don't need to get token
|
||||
headers = self.decrypt_headers()
|
||||
|
||||
# Option 2: Add OAuth token if authed and no headers provided
|
||||
if not self.headers and self.authed:
|
||||
token = self.retrieve_tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
return headers
|
||||
|
||||
@ -7,10 +7,11 @@ import urllib.parse
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
@ -106,15 +107,15 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
|
||||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
|
||||
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
|
||||
if b_query:
|
||||
url_for_resource_discovery += f"?{b_query}"
|
||||
if b_fragment:
|
||||
url_for_resource_discovery += f"#{b_fragment}"
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
response = httpx.get(url_for_resource_discovery, headers=headers)
|
||||
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
|
||||
if 200 <= response.status_code < 300:
|
||||
body = response.json()
|
||||
# Support both singular and plural forms
|
||||
@ -125,7 +126,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except httpx.RequestError:
|
||||
except RequestError:
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
@ -138,8 +139,8 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
|
||||
# The oauth_discovery_url is the authorization server base URL
|
||||
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||
urls_to_try = [
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
]
|
||||
else:
|
||||
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||
@ -148,15 +149,15 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
|
||||
|
||||
for url in urls_to_try:
|
||||
try:
|
||||
response = httpx.get(url, headers=headers)
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
continue
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
response.raise_for_status()
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
except (RequestError, HTTPStatusError) as e:
|
||||
if isinstance(e, ConnectError):
|
||||
response = ssrf_proxy.get(url)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
@ -232,7 +233,7 @@ def exchange_authorization(
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@ -252,7 +253,7 @@ def exchange_authorization(
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, data=params)
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
@ -265,7 +266,7 @@ def refresh_authorization(
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@ -283,7 +284,7 @@ def refresh_authorization(
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, data=params)
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
@ -296,7 +297,7 @@ def client_credentials_flow(
|
||||
scope: str | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Execute Client Credentials Flow to get access token."""
|
||||
grant_type = "client_credentials"
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@ -323,7 +324,7 @@ def client_credentials_flow(
|
||||
if client_information.client_secret:
|
||||
data["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, headers=headers, data=data)
|
||||
response = ssrf_proxy.post(token_url, headers=headers, data=data)
|
||||
if not response.is_success:
|
||||
raise ValueError(
|
||||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
@ -345,7 +346,7 @@ def register_client(
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = httpx.post(
|
||||
response = ssrf_proxy.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
@ -360,7 +361,6 @@ def auth(
|
||||
mcp_service: "MCPToolManageService",
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
grant_type: str = "authorization_code",
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
server_url = provider.decrypt_server_url()
|
||||
@ -371,25 +371,37 @@ def auth(
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
|
||||
# Check if we should use client credentials flow
|
||||
credentials = provider.decrypt_credentials()
|
||||
stored_grant_type = credentials.get("grant_type", "authorization_code")
|
||||
# Determine grant type based on server metadata
|
||||
if not server_metadata:
|
||||
raise ValueError("Failed to discover OAuth metadata from server")
|
||||
|
||||
# Use stored grant type if available, otherwise use parameter
|
||||
effective_grant_type = stored_grant_type or grant_type
|
||||
supported_grant_types = server_metadata.grant_types_supported or []
|
||||
|
||||
# Convert to lowercase for comparison
|
||||
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
|
||||
|
||||
# Determine which grant type to use
|
||||
effective_grant_type = None
|
||||
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
|
||||
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
else:
|
||||
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
# Get stored credentials
|
||||
credentials = provider.decrypt_credentials()
|
||||
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
|
||||
# For client credentials flow, we don't need to register client dynamically
|
||||
if effective_grant_type == "client_credentials":
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Client should provide client_id and client_secret directly
|
||||
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
|
||||
|
||||
try:
|
||||
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||
except httpx.RequestError as e:
|
||||
except RequestError as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
|
||||
# Save client information using service layer
|
||||
@ -400,7 +412,7 @@ def auth(
|
||||
client_information = full_information
|
||||
|
||||
# Handle client credentials flow
|
||||
if effective_grant_type == "client_credentials":
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
scope = credentials.get("scope")
|
||||
@ -413,11 +425,14 @@ def auth(
|
||||
|
||||
# Save tokens and grant type
|
||||
token_data = tokens.model_dump()
|
||||
token_data["grant_type"] = "client_credentials"
|
||||
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
|
||||
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Client credentials flow failed: {e}")
|
||||
|
||||
# Exchange authorization code for tokens (Authorization Code flow)
|
||||
@ -465,7 +480,10 @@ def auth(
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, new_tokens.model_dump(), "tokens")
|
||||
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow (only for authorization code flow)
|
||||
|
||||
@ -99,7 +99,11 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
# Clear authorization code after first use
|
||||
self.authorization_code = None
|
||||
|
||||
except MCPAuthError:
|
||||
# Re-raise MCPAuthError as is
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch all exceptions during auth retry
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -47,9 +48,9 @@ class ToolProviderApiEntity(BaseModel):
|
||||
|
||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||
authentication: dict[str, str] | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
|
||||
configuration: dict[str, str] | None = Field(
|
||||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
|
||||
@ -74,8 +75,14 @@ class ToolProviderApiEntity(BaseModel):
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(self.optional_field("configuration", self.configuration))
|
||||
optional_fields.update(self.optional_field("authentication", self.authentication))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
|
||||
@ -61,10 +61,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
"""
|
||||
create a MCPToolProviderController from a MCPProviderEntity
|
||||
"""
|
||||
try:
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||
except Exception:
|
||||
remote_mcp_tools = []
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||
|
||||
tools = [
|
||||
ToolEntity(
|
||||
@ -87,7 +84,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
if not db_provider.icon:
|
||||
if not entity.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
|
||||
@ -60,11 +60,18 @@ class MCPTool(Tool):
|
||||
|
||||
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process text content and yield appropriate messages."""
|
||||
try:
|
||||
content_json = json.loads(content.text)
|
||||
yield from self._process_json_content(content_json)
|
||||
except json.JSONDecodeError:
|
||||
yield self.create_text_message(content.text)
|
||||
# Check if content looks like JSON before attempting to parse
|
||||
text = content.text.strip()
|
||||
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
|
||||
try:
|
||||
content_json = json.loads(text)
|
||||
yield from self._process_json_content(content_json)
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If not JSON or parsing failed, treat as plain text
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
@ -119,10 +126,6 @@ class MCPTool(Tool):
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
# Get provider entity to access tokens
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# Get MCP service from invoke parameters or create new one
|
||||
provider_entity = None
|
||||
@ -131,18 +134,7 @@ class MCPTool(Tool):
|
||||
# Check if mcp_service is passed in tool_parameters
|
||||
if "_mcp_service" in tool_parameters:
|
||||
mcp_service = tool_parameters.pop("_mcp_service")
|
||||
else:
|
||||
# Fallback to creating service with database session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
if mcp_service:
|
||||
try:
|
||||
if mcp_service:
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
headers = provider_entity.decrypt_headers()
|
||||
# Try to get existing token and add to headers
|
||||
@ -150,23 +142,54 @@ class MCPTool(Tool):
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
except Exception:
|
||||
# If provider retrieval or token fails, continue without auth
|
||||
pass
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except (ValueError, TypeError, KeyError) as e:
|
||||
# Catch specific exceptions that might occur during tool invocation
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
else:
|
||||
# Fallback to creating service with database session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
headers = provider_entity.decrypt_headers()
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
@ -9,8 +9,7 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
@ -24,7 +23,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
DEFAULT_GRANT_TYPE = "authorization_code"
|
||||
CLIENT_NAME = "Dify"
|
||||
EMPTY_TOOLS_JSON = "[]"
|
||||
EMPTY_CREDENTIALS_JSON = "{}"
|
||||
@ -88,12 +86,9 @@ class MCPToolManageService:
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
grant_type: str = DEFAULT_GRANT_TYPE,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
@ -104,9 +99,11 @@ class MCPToolManageService:
|
||||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
if client_id and client_secret:
|
||||
if authentication is not None:
|
||||
# Build the full credentials structure with encrypted client_id and client_secret
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(client_id, client_secret, grant_type, tenant_id)
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
authentication.client_id, authentication.client_secret, tenant_id
|
||||
)
|
||||
else:
|
||||
encrypted_credentials = None
|
||||
# Create provider
|
||||
@ -120,16 +117,16 @@ class MCPToolManageService:
|
||||
tools=EMPTY_TOOLS_JSON,
|
||||
icon=self._prepare_icon(icon, icon_type, icon_background),
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
)
|
||||
|
||||
self._session.add(mcp_tool)
|
||||
self._session.commit()
|
||||
|
||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
self._session.flush()
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
def update_provider(
|
||||
self,
|
||||
@ -142,12 +139,9 @@ class MCPToolManageService:
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
grant_type: str | None = None,
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
) -> None:
|
||||
"""Update an MCP provider."""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
@ -185,10 +179,10 @@ class MCPToolManageService:
|
||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||
|
||||
# Update optional fields
|
||||
if timeout is not None:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
if configuration.timeout is not None:
|
||||
mcp_provider.timeout = configuration.timeout
|
||||
if configuration.sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = configuration.sse_read_timeout
|
||||
if headers is not None:
|
||||
if headers:
|
||||
# Build headers preserving unchanged masked values
|
||||
@ -200,20 +194,18 @@ class MCPToolManageService:
|
||||
mcp_provider.encrypted_headers = None
|
||||
|
||||
# Update credentials if provided
|
||||
if client_id is not None and client_secret is not None:
|
||||
if authentication is not None:
|
||||
# Merge with existing credentials to handle masked values
|
||||
(
|
||||
final_client_id,
|
||||
final_client_secret,
|
||||
final_grant_type,
|
||||
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, mcp_provider)
|
||||
|
||||
# Use default grant_type if none found
|
||||
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
|
||||
) = self._merge_credentials_with_masked(
|
||||
authentication.client_id, authentication.client_secret, mcp_provider
|
||||
)
|
||||
|
||||
# Build and encrypt new credentials
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
final_client_id, final_client_secret, final_grant_type, tenant_id
|
||||
final_client_id, final_client_secret, tenant_id
|
||||
)
|
||||
mcp_provider.encrypted_credentials = encrypted_credentials
|
||||
|
||||
@ -221,7 +213,11 @@ class MCPToolManageService:
|
||||
except IntegrityError as e:
|
||||
self._session.rollback()
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
except Exception:
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
# Catch specific exceptions that might occur during update
|
||||
# ValueError: invalid data provided
|
||||
# AttributeError: missing required attributes
|
||||
# TypeError: type conversion errors
|
||||
self._session.rollback()
|
||||
raise
|
||||
|
||||
@ -271,7 +267,7 @@ class MCPToolManageService:
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
|
||||
# Build API response
|
||||
return self._build_tool_provider_response(db_provider, provider_entity, tools)
|
||||
@ -309,7 +305,7 @@ class MCPToolManageService:
|
||||
if not authed:
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
|
||||
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
|
||||
"""
|
||||
@ -495,20 +491,21 @@ class MCPToolManageService:
|
||||
def _merge_credentials_with_masked(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
grant_type: str | None,
|
||||
client_secret: str | None,
|
||||
mcp_provider: MCPToolProvider,
|
||||
) -> tuple[str, str, str | None]:
|
||||
) -> tuple[
|
||||
str,
|
||||
str | None,
|
||||
]:
|
||||
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
client_id: Client ID from frontend (may be masked)
|
||||
client_secret: Client secret from frontend (may be masked)
|
||||
grant_type: Grant type from frontend
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Tuple of (final_client_id, final_client_secret, grant_type)
|
||||
Tuple of (final_client_id, final_client_secret)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_credentials()
|
||||
@ -526,35 +523,18 @@ class MCPToolManageService:
|
||||
# Use existing decrypted value
|
||||
final_client_secret = existing_decrypted.get("client_secret", client_secret)
|
||||
|
||||
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
|
||||
return final_client_id, final_client_secret
|
||||
|
||||
return final_client_id, final_client_secret, final_grant_type
|
||||
|
||||
def _build_and_encrypt_credentials(
|
||||
self, client_id: str, client_secret: str, grant_type: str, tenant_id: str
|
||||
) -> str:
|
||||
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
|
||||
"""Build credentials and encrypt sensitive fields."""
|
||||
# Create a flat structure with all credential data
|
||||
credentials_data = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"grant_type": grant_type,
|
||||
"client_name": CLIENT_NAME,
|
||||
"is_dynamic_registration": False,
|
||||
}
|
||||
|
||||
# Add grant types and response types based on grant_type
|
||||
if grant_type == "client_credentials":
|
||||
credentials_data["grant_types"] = json.dumps(["client_credentials"])
|
||||
credentials_data["response_types"] = json.dumps([])
|
||||
credentials_data["redirect_uris"] = json.dumps([])
|
||||
else:
|
||||
credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"])
|
||||
credentials_data["response_types"] = json.dumps(["code"])
|
||||
credentials_data["redirect_uris"] = json.dumps(
|
||||
[f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"]
|
||||
)
|
||||
|
||||
# Only client_id and client_secret need encryption
|
||||
secret_fields = ["client_id", "client_secret"]
|
||||
secret_fields = ["client_id", "client_secret"] if client_secret else ["client_id"]
|
||||
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, Union
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
|
||||
@ -246,6 +247,13 @@ class ToolTransformService:
|
||||
)
|
||||
response["server_identifier"] = db_provider.server_identifier
|
||||
|
||||
# Convert configuration dict to MCPConfiguration object
|
||||
if "configuration" in response and isinstance(response["configuration"], dict):
|
||||
response["configuration"] = MCPConfiguration(
|
||||
timeout=float(response["configuration"]["timeout"]),
|
||||
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
|
||||
)
|
||||
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -139,7 +139,7 @@ class TestRedisStateManagement:
|
||||
class TestOAuthDiscovery:
|
||||
"""Test OAuth discovery functions."""
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_check_support_resource_discovery_success(self, mock_get):
|
||||
"""Test successful resource discovery check."""
|
||||
mock_response = Mock()
|
||||
@ -152,11 +152,11 @@ class TestOAuthDiscovery:
|
||||
assert supported is True
|
||||
assert auth_url == "https://auth.example.com"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-protected-resource/endpoint",
|
||||
"https://api.example.com/.well-known/oauth-protected-resource",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||
)
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_check_support_resource_discovery_not_supported(self, mock_get):
|
||||
"""Test resource discovery not supported."""
|
||||
mock_response = Mock()
|
||||
@ -168,7 +168,7 @@ class TestOAuthDiscovery:
|
||||
assert supported is False
|
||||
assert auth_url == ""
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
|
||||
"""Test resource discovery with query and fragment."""
|
||||
mock_response = Mock()
|
||||
@ -181,11 +181,11 @@ class TestOAuthDiscovery:
|
||||
assert supported is True
|
||||
assert auth_url == "https://auth.example.com"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-protected-resource/path?query=1#fragment",
|
||||
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||
)
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
|
||||
"""Test OAuth metadata discovery with resource discovery support."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
@ -207,11 +207,11 @@ class TestOAuthDiscovery:
|
||||
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||
assert metadata.token_endpoint == "https://auth.example.com/token"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://auth.example.com/.well-known/openid-configuration",
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||
)
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
|
||||
"""Test OAuth metadata discovery without resource discovery."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
@ -236,7 +236,7 @@ class TestOAuthDiscovery:
|
||||
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||
)
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_not_found(self, mock_get):
|
||||
"""Test OAuth metadata discovery when not found."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
@ -336,7 +336,7 @@ class TestAuthorizationFlow:
|
||||
|
||||
assert "does not support response type code" in str(exc_info.value)
|
||||
|
||||
@patch("httpx.post")
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_exchange_authorization_success(self, mock_post):
|
||||
"""Test successful authorization code exchange."""
|
||||
mock_response = Mock()
|
||||
@ -384,7 +384,7 @@ class TestAuthorizationFlow:
|
||||
},
|
||||
)
|
||||
|
||||
@patch("httpx.post")
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_exchange_authorization_failure(self, mock_post):
|
||||
"""Test failed authorization code exchange."""
|
||||
mock_response = Mock()
|
||||
@ -406,7 +406,7 @@ class TestAuthorizationFlow:
|
||||
|
||||
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
|
||||
|
||||
@patch("httpx.post")
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_refresh_authorization_success(self, mock_post):
|
||||
"""Test successful token refresh."""
|
||||
mock_response = Mock()
|
||||
@ -442,7 +442,7 @@ class TestAuthorizationFlow:
|
||||
},
|
||||
)
|
||||
|
||||
@patch("httpx.post")
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_register_client_success(self, mock_post):
|
||||
"""Test successful client registration."""
|
||||
mock_response = Mock()
|
||||
@ -576,7 +576,12 @@ class TestAuthOrchestration:
|
||||
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for new client registration."""
|
||||
# Setup
|
||||
mock_discover.return_value = None
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
mock_register.return_value = OAuthClientInformationFull(
|
||||
client_id="new-client-id",
|
||||
client_name="Dify",
|
||||
@ -679,7 +684,12 @@ class TestAuthOrchestration:
|
||||
mock_refresh.return_value = new_tokens
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
|
||||
mock_discover.return_value = None
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
|
||||
result = auth(mock_provider, mock_service)
|
||||
|
||||
|
||||
@ -228,8 +228,7 @@ class TestMCPToolTransform:
|
||||
"masked_headers": {"Authorization": "Bearer *****"},
|
||||
"updated_at": 1234567890,
|
||||
"labels": [],
|
||||
"timeout": 30,
|
||||
"sse_read_timeout": 300,
|
||||
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
|
||||
"original_headers": {"Authorization": "Bearer secret-token"},
|
||||
"author": "Test User",
|
||||
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
|
||||
@ -246,8 +245,9 @@ class TestMCPToolTransform:
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
|
||||
assert result.server_identifier == "server-identifier-456"
|
||||
assert result.timeout == 30
|
||||
assert result.sse_read_timeout == 300
|
||||
assert result.configuration is not None
|
||||
assert result.configuration.timeout == 30
|
||||
assert result.configuration.sse_read_timeout == 300
|
||||
assert result.original_headers == {"Authorization": "Bearer secret-token"}
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].description.en_US == "Tool description"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user