feat: add client credentials auth

This commit is contained in:
Novice 2025-10-09 17:54:46 +08:00
parent 3592240d14
commit 740f970041
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
10 changed files with 609 additions and 142 deletions

View File

@ -867,6 +867,12 @@ class ToolProviderMCPApi(Resource):
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("client_id", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument(
"grant_type", type=str, required=False, nullable=True, location="json", default="authorization_code"
)
parser.add_argument("scope", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
@ -885,6 +891,10 @@ class ToolProviderMCPApi(Resource):
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
client_id=args["client_id"],
client_secret=args["client_secret"],
grant_type=args["grant_type"],
scope=args["scope"],
)
session.commit()
return jsonable_encoder(result)
@ -904,6 +914,10 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
parser.add_argument("client_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json")
parser.add_argument("grant_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("scope", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
@ -924,6 +938,10 @@ class ToolProviderMCPApi(Resource):
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
client_id=args.get("client_id"),
client_secret=args.get("client_secret"),
grant_type=args.get("grant_type"),
scope=args.get("scope"),
)
session.commit()
return {"result": "success"}

View File

@ -18,6 +18,15 @@ from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
DEFAULT_GRANT_TYPE = "authorization_code"
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
@ -78,13 +87,38 @@ class MCPProviderEntity(BaseModel):
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
# Get grant type from credentials
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", DEFAULT_GRANT_TYPE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = "client_credentials"
elif "authorization_code" in client_info["grant_types"]:
grant_type = "authorization_code"
# Configure based on grant type
is_client_credentials = grant_type == "client_credentials"
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
response_types = [] if is_client_credentials else ["code"]
redirect_uris = [] if is_client_credentials else [self.redirect_url]
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
redirect_uris=redirect_uris,
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
grant_types=grant_types,
response_types=response_types,
client_name=CLIENT_NAME,
client_uri=CLIENT_URI,
)
@property
@ -100,7 +134,7 @@ class MCPProviderEntity(BaseModel):
def to_api_response(self, user_name: str | None = None) -> dict[str, Any]:
"""Convert to API response format"""
return {
response = {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
@ -117,11 +151,50 @@ class MCPProviderEntity(BaseModel):
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add masked credentials if they exist
masked_creds = self.masked_credentials()
if masked_creds:
response.update(masked_creds)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
"""OAuth client information if available"""
client_info = self.decrypt_credentials().get("client_information", {})
if not client_info:
credentials = self.decrypt_credentials()
if not credentials:
return None
# Check if we have nested client_information structure
if "client_information" in credentials:
# Handle nested structure (Authorization Code flow)
client_info_data = credentials["client_information"]
if isinstance(client_info_data, dict):
return OAuthClientInformation.model_validate(client_info_data)
return None
# Handle flat structure (Client Credentials flow)
if "client_id" not in credentials:
return None
# Build client information from flat structure
client_info = {
"client_id": credentials.get("client_id", ""),
"client_secret": credentials.get("client_secret", ""),
"client_name": credentials.get("client_name", CLIENT_NAME),
}
# Parse JSON fields if they exist
json_fields = ["redirect_uris", "grant_types", "response_types"]
for field in json_fields:
if field in credentials:
try:
client_info[field] = json.loads(credentials[field])
except:
client_info[field] = []
if "scope" in credentials:
client_info["scope"] = credentials["scope"]
return OAuthClientInformation.model_validate(client_info)
def retrieve_tokens(self) -> OAuthTokens | None:
@ -131,8 +204,8 @@ class MCPProviderEntity(BaseModel):
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
)
@ -144,30 +217,77 @@ class MCPProviderEntity(BaseModel):
return f"{base_url}/******"
return base_url
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
if len(value) > MIN_UNMASK_LENGTH:
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
else:
return MASK_CHAR * len(value)
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
masked: dict[str, str] = {}
for key, value in self.decrypt_headers().items():
if len(value) > 6:
masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
else:
masked[key] = "*" * len(value)
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
def masked_credentials(self) -> dict[str, str]:
"""Masked credentials for display"""
credentials = self.decrypt_credentials()
if not credentials:
return {}
masked = {}
# Check if we have nested client_information structure
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# Mask sensitive fields from nested structure
if client_info.get("client_id"):
masked["client_id"] = self._mask_value(client_info["client_id"])
if client_info.get("client_secret"):
masked["client_secret"] = self._mask_value(client_info["client_secret"])
else:
# Handle flat structure
# Mask sensitive fields
sensitive_fields = ["client_id", "client_secret"]
for field in sensitive_fields:
if credentials.get(field):
masked[field] = self._mask_value(credentials[field])
# Include non-sensitive fields (check both flat and nested structures)
if "grant_type" in credentials:
masked["grant_type"] = credentials["grant_type"]
if "scope" in credentials:
masked["scope"] = credentials["scope"]
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
try:
if not self.headers:
if not data:
return {}
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers]
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
@ -175,28 +295,21 @@ class MCPProviderEntity(BaseModel):
cache=NoOpProviderCredentialCache(),
)
result = encrypter_instance.decrypt(self.headers)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
except Exception:
return {}
def decrypt_credentials(
self,
) -> dict[str, Any]:
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
try:
if not self.credentials:
return {}
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=[
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key)
for key in self.credentials
],
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials)
except Exception:
return {}
return self._decrypt_dict(self.credentials)

View File

@ -106,8 +106,8 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
@ -117,7 +117,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
response = httpx.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
if "authorization_server_url" in body:
# Support both singular and plural forms
if body.get("authorization_servers"):
return True, body["authorization_servers"][0]
elif body.get("authorization_server_url"):
return True, body["authorization_server_url"][0]
else:
return False, ""
@ -132,27 +135,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
url = oauth_discovery_url
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
]
else:
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = httpx.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
for url in urls_to_try:
try:
response = httpx.get(url, headers=headers)
if response.status_code == 404:
return None
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
raise
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
return None # No metadata found
def start_authorization(
@ -276,6 +289,49 @@ def refresh_authorization(
return OAuthTokens.model_validate(response.json())
def client_credentials_flow(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = "client_credentials"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Support both Basic Auth and body parameters for client authentication
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {"grant_type": grant_type}
if scope:
data["scope"] = scope
# If client_secret is provided, use Basic Auth (preferred method)
if client_information.client_secret:
credentials = f"{client_information.client_id}:{client_information.client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
else:
# Fall back to including credentials in the body
data["client_id"] = client_information.client_id
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = httpx.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
def register_client(
server_url: str,
metadata: OAuthMetadata | None,
@ -304,6 +360,7 @@ def auth(
mcp_service: "MCPToolManageService",
authorization_code: str | None = None,
state_param: str | None = None,
grant_type: str = "authorization_code",
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
server_url = provider.decrypt_server_url()
@ -314,9 +371,22 @@ def auth(
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Check if we should use client credentials flow
credentials = provider.decrypt_credentials()
stored_grant_type = credentials.get("grant_type", "authorization_code")
# Use stored grant type if available, otherwise use parameter
effective_grant_type = stored_grant_type or grant_type
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == "client_credentials":
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, server_metadata, client_metadata)
except httpx.RequestError as e:
@ -329,7 +399,28 @@ def auth(
client_information = full_information
# Exchange authorization code for tokens
# Handle client credentials flow
if effective_grant_type == "client_credentials":
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
)
# Save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = "client_credentials"
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
return {"result": "success"}
except Exception as e:
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
@ -377,7 +468,7 @@ def auth(
except Exception as e:
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
# Start new authorization flow (only for authorization code flow)
authorization_url, code_verifier = start_authorization(
server_url,
server_metadata,

View File

@ -47,6 +47,11 @@ class ToolProviderApiEntity(BaseModel):
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
# MCP OAuth credentials
client_id: str | None = Field(default=None, description="The masked client ID for OAuth")
client_secret: str | None = Field(default=None, description="The masked client secret for OAuth")
grant_type: str | None = Field(default=None, description="The OAuth grant type")
scope: str | None = Field(default=None, description="The OAuth scope")
@field_validator("tools", mode="before")
@classmethod
@ -72,6 +77,10 @@ class ToolProviderApiEntity(BaseModel):
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("client_id", self.client_id))
optional_fields.update(self.optional_field("client_secret", self.client_secret))
optional_fields.update(self.optional_field("grant_type", self.grant_type))
optional_fields.update(self.optional_field("scope", self.scope))
return {
"id": self.id,
"author": self.author,

View File

@ -9,6 +9,7 @@ from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.mcp_provider import MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
@ -21,7 +22,12 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
# Constants
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
DEFAULT_GRANT_TYPE = "authorization_code"
CLIENT_NAME = "Dify"
EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
class MCPToolManageService:
@ -85,6 +91,10 @@ class MCPToolManageService:
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str = DEFAULT_GRANT_TYPE,
scope: str | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
@ -94,8 +104,14 @@ class MCPToolManageService:
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
if client_id and client_secret:
# Build the full credentials structure with encrypted client_id and client_secret
encrypted_credentials = self._build_and_encrypt_credentials(
client_id, client_secret, grant_type, scope, tenant_id
)
else:
encrypted_credentials = None
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
@ -104,12 +120,13 @@ class MCPToolManageService:
server_url_hash=server_url_hash,
user_id=user_id,
authed=False,
tools="[]",
tools=EMPTY_TOOLS_JSON,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
encrypted_credentials=encrypted_credentials,
)
self._session.add(mcp_tool)
@ -131,6 +148,10 @@ class MCPToolManageService:
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str | None = None,
scope: str | None = None,
) -> None:
"""Update an MCP provider."""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
@ -176,11 +197,31 @@ class MCPToolManageService:
if headers:
# Build headers preserving unchanged masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
encrypted_headers_dict = self._prepare_encrypted_headers(final_headers, tenant_id)
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
mcp_provider.encrypted_headers = encrypted_headers_dict
else:
# Clear headers if empty dict passed
mcp_provider.encrypted_headers = None
# Update credentials if provided
if client_id is not None and client_secret is not None:
# Merge with existing credentials to handle masked values
(
final_client_id,
final_client_secret,
final_grant_type,
final_scope,
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider)
# Use default grant_type if none found
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
# Build and encrypt new credentials
encrypted_credentials = self._build_and_encrypt_credentials(
final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id
)
mcp_provider.encrypted_credentials = encrypted_credentials
self._session.commit()
except IntegrityError as e:
self._session.rollback()
@ -271,7 +312,7 @@ class MCPToolManageService:
if authed is not None:
provider.authed = authed
if not authed:
provider.tools = "[]"
provider.tools = EMPTY_TOOLS_JSON
self._session.commit()
@ -287,28 +328,15 @@ class MCPToolManageService:
"""
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
credentials = {}
authed = None
# Determine if this makes the provider authenticated
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
if data_type == "tokens" or (data_type == "mixed" and "access_token" in data):
# OAuth tokens
credentials = data
authed = True
elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data):
# OAuth client information
credentials = data
elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data):
# PKCE code verifier
credentials = data
else:
credentials = data
self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed)
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear all credentials for a provider."""
provider.tools = "[]"
provider.encrypted_credentials = "{}"
provider.tools = EMPTY_TOOLS_JSON
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
@ -341,13 +369,24 @@ class MCPToolManageService:
return json.dumps({"content": icon, "background": icon_background})
return icon
def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> str:
"""Encrypt specified fields in a dictionary.
Args:
data: Dictionary containing data to encrypt
secret_fields: List of field names to encrypt
tenant_id: Tenant ID for encryption
Returns:
JSON string of encrypted data
"""
from core.entities.provider_entities import BasicProviderConfig
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
# Create config for secret fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
@ -355,8 +394,13 @@ class MCPToolManageService:
cache=NoOpProviderCredentialCache(),
)
encrypted_headers_dict = encrypter_instance.encrypt(headers)
return json.dumps(encrypted_headers_dict)
encrypted_data = encrypter_instance.encrypt(data)
return json.dumps(encrypted_data)
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
# All headers are treated as secret
return self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id)
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
@ -391,27 +435,18 @@ class MCPToolManageService:
provider_entity = provider.to_entity()
headers = provider_entity.headers
timeout = provider_entity.timeout
sse_read_timeout = provider_entity.sse_read_timeout
try:
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
provider_entity=provider_entity,
auth_callback=lambda p, s, c: auth(p, self, c),
mcp_service=self,
) as mcp_client:
tools = mcp_client.list_tools()
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": "{}",
}
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
)
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
}
except MCPAuthError:
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
@ -461,3 +496,76 @@ class MCPToolManageService:
for key, value in incoming_headers.items()
if key in existing_decrypted or value != existing_masked.get(key)
}
def _merge_credentials_with_masked(
self,
client_id: str,
client_secret: str,
grant_type: str | None,
scope: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[str, str, str | None, str | None]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
grant_type: Grant type from frontend
scope: OAuth scope from frontend
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret, grant_type, scope)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
existing_masked = mcp_provider_entity.masked_credentials()
# Check if client_id is masked and unchanged
final_client_id = client_id
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
# Use existing decrypted value
final_client_id = existing_decrypted.get("client_id", client_id)
# Check if client_secret is masked and unchanged
final_client_secret = client_secret
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
# Grant type and scope are not masked, use as is
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
final_scope = scope if scope is not None else existing_decrypted.get("scope")
return final_client_id, final_client_secret, final_grant_type, final_scope
def _build_and_encrypt_credentials(
self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str
) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"client_secret": client_secret,
"grant_type": grant_type,
"client_name": CLIENT_NAME,
}
if scope:
credentials_data["scope"] = scope
# Add grant types and response types based on grant_type
if grant_type == "client_credentials":
credentials_data["grant_types"] = json.dumps(["client_credentials"])
credentials_data["response_types"] = json.dumps([])
credentials_data["redirect_uris"] = json.dumps([])
else:
credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"])
credentials_data["response_types"] = json.dumps(["code"])
credentials_data["redirect_uris"] = json.dumps(
[f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"]
)
# Only client_id and client_secret need encryption
secret_fields = ["client_id", "client_secret"]
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)

View File

@ -2,13 +2,14 @@
import React, { useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { getDomain } from 'tldts'
import { RiCloseLine, RiEditLine } from '@remixicon/react'
import { RiArrowDownSLine, RiCloseLine, RiEditLine } from '@remixicon/react'
import AppIconPicker from '@/app/components/base/app-icon-picker'
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
import AppIcon from '@/app/components/base/app-icon'
import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Select from '@/app/components/base/select'
import HeadersInput from './headers-input'
import type { AppIconType } from '@/types/app'
import type { ToolWithProvider } from '@/app/components/workflow/types'
@ -31,6 +32,10 @@ export type DuplicateAppModalProps = {
timeout: number
sse_read_timeout: number
headers?: Record<string, string>
client_id?: string
client_secret?: string
grant_type?: string
scope?: string
}) => void
onHide: () => void
}
@ -73,6 +78,12 @@ const MCPModal = ({
const [headers, setHeaders] = React.useState<Record<string, string>>(
data?.masked_headers || {},
)
const [clientId, setClientId] = React.useState(data?.client_id || '')
const [clientSecret, setClientSecret] = React.useState(data?.client_secret || '')
const [grantType, setGrantType] = React.useState(data?.grant_type || 'authorization_code')
const [scope, setScope] = React.useState(data?.scope || '')
const [authCollapsed, setAuthCollapsed] = React.useState(true)
const [configCollapsed, setConfigCollapsed] = React.useState(true)
const [isFetchingIcon, setIsFetchingIcon] = useState(false)
const appIconRef = useRef<HTMLDivElement>(null)
const isHovering = useHover(appIconRef)
@ -86,6 +97,10 @@ const MCPModal = ({
setMcpTimeout(data.timeout || 30)
setSseReadTimeout(data.sse_read_timeout || 300)
setHeaders(data.masked_headers || {})
setClientId(data.client_id || '')
setClientSecret(data.client_secret || '')
setGrantType(data.grant_type || 'authorization_code')
setScope(data.scope || '')
setAppIcon(getIcon(data))
}
else {
@ -96,6 +111,10 @@ const MCPModal = ({
setMcpTimeout(30)
setSseReadTimeout(300)
setHeaders({})
setClientId('')
setClientSecret('')
setGrantType('authorization_code')
setScope('')
setAppIcon(DEFAULT_ICON as AppIconSelection)
}
}, [data])
@ -124,7 +143,8 @@ const MCPModal = ({
setIsFetchingIcon(true)
try {
const res = await uploadRemoteFileInfo(remoteIcon, undefined, true)
setAppIcon({ type: 'image', url: res.url, fileId: extractFileId(res.url) || '' })
if ('url' in res)
setAppIcon({ type: 'image', url: res.url, fileId: extractFileId(res.url) || '' })
}
catch (e) {
let errorMessage = 'Failed to fetch remote icon'
@ -158,6 +178,10 @@ const MCPModal = ({
timeout: timeout || 30,
sse_read_timeout: sseReadTimeout || 300,
headers: Object.keys(headers).length > 0 ? headers : undefined,
client_id: clientId || undefined,
client_secret: clientSecret || undefined,
grant_type: grantType,
scope: scope || undefined,
})
if(isCreate)
onHide()
@ -236,41 +260,116 @@ const MCPModal = ({
</div>
)}
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.timeout')}</span>
<div
className='mb-1 flex h-6 cursor-pointer items-center justify-between'
onClick={() => setAuthCollapsed(!authCollapsed)}
>
<span className='system-sm-semibold-uppercase text-text-secondary'>{t('tools.mcp.modal.authentication')}</span>
<RiArrowDownSLine className={cn('h-4 w-4 text-text-tertiary transition-transform', authCollapsed && '-rotate-90')} />
</div>
<Input
type='number'
value={timeout}
onChange={e => setMcpTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
{!authCollapsed && (
<div className='mt-3 space-y-5'>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.grantType')}</span>
</div>
<Select
items={[
{ value: 'authorization_code', name: t('tools.mcp.modal.grantTypeAuthCode') },
{ value: 'client_credentials', name: t('tools.mcp.modal.grantTypeClientCredentials') },
]}
defaultValue={grantType}
onSelect={item => setGrantType(item.value as string)}
placeholder={t('tools.mcp.modal.grantType')}
/>
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.clientId')}</span>
</div>
<Input
value={clientId}
onChange={e => setClientId(e.target.value)}
placeholder={t('tools.mcp.modal.clientIdPlaceholder')}
/>
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.clientSecret')}</span>
</div>
<Input
type='password'
value={clientSecret}
onChange={e => setClientSecret(e.target.value)}
placeholder={t('tools.mcp.modal.clientSecretPlaceholder')}
/>
</div>
{grantType === 'client_credentials' && (
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.scope')}</span>
</div>
<Input
value={scope}
onChange={e => setScope(e.target.value)}
placeholder={t('tools.mcp.modal.scopePlaceholder')}
/>
</div>
)}
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.headers')}</span>
</div>
<div className='body-xs-regular mb-2 text-text-tertiary'>{t('tools.mcp.modal.headersTip')}</div>
<HeadersInput
headers={headers}
onChange={setHeaders}
readonly={false}
isMasked={!isCreate && Object.keys(headers).length > 0}
/>
</div>
</div>
)}
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.sseReadTimeout')}</span>
<div
className='mb-1 flex h-6 cursor-pointer items-center justify-between'
onClick={() => setConfigCollapsed(!configCollapsed)}
>
<span className='system-sm-semibold-uppercase text-text-secondary'>{t('tools.mcp.modal.configuration')}</span>
<RiArrowDownSLine className={cn('h-4 w-4 text-text-tertiary transition-transform', configCollapsed && '-rotate-90')} />
</div>
<Input
type='number'
value={sseReadTimeout}
onChange={e => setSseReadTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.headers')}</span>
</div>
<div className='body-xs-regular mb-2 text-text-tertiary'>{t('tools.mcp.modal.headersTip')}</div>
<HeadersInput
headers={headers}
onChange={setHeaders}
readonly={false}
isMasked={!isCreate && Object.keys(headers).length > 0}
/>
{!configCollapsed && (
<div className='mt-3 space-y-5'>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.timeout')}</span>
</div>
<Input
type='number'
value={timeout}
onChange={e => setMcpTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.sseReadTimeout')}</span>
</div>
<Input
type='number'
value={sseReadTimeout}
onChange={e => setSseReadTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
</div>
)}
</div>
</div>
<div className='flex flex-row-reverse pt-5'>

View File

@ -61,6 +61,10 @@ export type Collection = {
sse_read_timeout?: number
headers?: Record<string, string>
masked_headers?: Record<string, string>
client_id?: string
client_secret?: string
grant_type?: string
scope?: string
}
export type ToolParameter = {

View File

@ -203,6 +203,17 @@ const translation = {
timeout: 'Timeout',
sseReadTimeout: 'SSE Read Timeout',
timeoutPlaceholder: '30',
configuration: 'Configuration',
authentication: 'Authentication',
grantType: 'Grant Type',
grantTypeAuthCode: 'Authorization Code (User Authentication)',
grantTypeClientCredentials: 'Client Credentials (Service-to-Service)',
scope: 'OAuth Scope',
scopePlaceholder: 'Enter OAuth scope (optional)',
clientId: 'Client ID',
clientIdPlaceholder: 'Enter client ID',
clientSecret: 'Client Secret',
clientSecretPlaceholder: 'Enter client secret',
},
delete: 'Remove MCP Server',
deleteConfirmTitle: 'Would you like to remove {{mcp}}?',

View File

@ -203,6 +203,12 @@ const translation = {
timeout: '超时时间',
sseReadTimeout: 'SSE 读取超时时间',
timeoutPlaceholder: '30',
configuration: '配置',
authentication: '认证',
clientId: '客户端 ID',
clientIdPlaceholder: '请输入客户端 ID',
clientSecret: '客户端密钥',
clientSecretPlaceholder: '请输入客户端密钥',
},
delete: '删除 MCP 服务',
deleteConfirmTitle: '你想要删除 {{mcp}} 吗?',

View File

@ -88,6 +88,10 @@ export const useCreateMCP = () => {
timeout?: number
sse_read_timeout?: number
headers?: Record<string, string>
client_id?: string
client_secret?: string
grant_type?: string
scope?: string
}) => {
return post<ToolWithProvider>('workspaces/current/tool-provider/mcp', {
body: {
@ -115,6 +119,10 @@ export const useUpdateMCP = ({
timeout?: number
sse_read_timeout?: number
headers?: Record<string, string>
client_id?: string
client_secret?: string
grant_type?: string
scope?: string
}) => {
return put('workspaces/current/tool-provider/mcp', {
body: {