mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/deploy/dev' into deploy/dev
This commit is contained in:
commit
946f0b00e4
|
|
@ -18,7 +18,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
|
|
@ -1007,7 +1007,12 @@ class ToolMCPAuthApi(Resource):
|
|||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
service = MCPToolManageService(session=session)
|
||||
return auth(provider_entity, service, args.get("authorization_code"))
|
||||
try:
|
||||
return auth(provider_entity, service, args.get("authorization_code"))
|
||||
except MCPRefreshTokenError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
with session.begin():
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
|
|
|
|||
|
|
@ -189,34 +189,16 @@ class MCPProviderEntity(BaseModel):
|
|||
return None
|
||||
|
||||
# Check if we have nested client_information structure
|
||||
if "client_information" in credentials:
|
||||
# Handle nested structure (Authorization Code flow)
|
||||
client_info_data = credentials["client_information"]
|
||||
if isinstance(client_info_data, dict):
|
||||
return OAuthClientInformation.model_validate(client_info_data)
|
||||
if "client_information" not in credentials:
|
||||
return None
|
||||
|
||||
# Handle flat structure (Client Credentials flow)
|
||||
if "client_id" not in credentials:
|
||||
return None
|
||||
|
||||
# Build client information from flat structure
|
||||
client_info = {
|
||||
"client_id": credentials.get("client_id", ""),
|
||||
"client_secret": credentials.get("client_secret", ""),
|
||||
"client_name": credentials.get("client_name", CLIENT_NAME),
|
||||
}
|
||||
|
||||
# Parse JSON fields if they exist
|
||||
json_fields = ["redirect_uris", "grant_types", "response_types"]
|
||||
for field in json_fields:
|
||||
if field in credentials:
|
||||
client_info[field] = json.loads(credentials[field])
|
||||
|
||||
if "scope" in credentials:
|
||||
client_info["scope"] = credentials["scope"]
|
||||
|
||||
return OAuthClientInformation.model_validate(client_info)
|
||||
client_info_data = credentials["client_information"]
|
||||
if isinstance(client_info_data, dict):
|
||||
if "encrypted_client_secret" in client_info_data:
|
||||
client_info_data["client_secret"] = encrypter.decrypt_token(
|
||||
self.tenant_id, client_info_data["encrypted_client_secret"]
|
||||
)
|
||||
return OAuthClientInformation.model_validate(client_info_data)
|
||||
return None
|
||||
|
||||
def retrieve_tokens(self) -> OAuthTokens | None:
|
||||
"""OAuth tokens if available"""
|
||||
|
|
@ -257,26 +239,18 @@ class MCPProviderEntity(BaseModel):
|
|||
|
||||
masked = {}
|
||||
|
||||
# Check if we have nested client_information structure
|
||||
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||
client_info = credentials["client_information"]
|
||||
# Mask sensitive fields from nested structure
|
||||
if client_info.get("client_id"):
|
||||
masked["client_id"] = self._mask_value(client_info["client_id"])
|
||||
if client_info.get("client_secret"):
|
||||
masked["client_secret"] = self._mask_value(client_info["client_secret"])
|
||||
else:
|
||||
# Handle flat structure
|
||||
# Mask sensitive fields
|
||||
sensitive_fields = ["client_id", "client_secret"]
|
||||
for field in sensitive_fields:
|
||||
if credentials.get(field):
|
||||
masked[field] = self._mask_value(credentials[field])
|
||||
|
||||
# Include non-sensitive fields (check both flat and nested structures)
|
||||
if "grant_type" in credentials:
|
||||
masked["grant_type"] = credentials["grant_type"]
|
||||
|
||||
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
|
||||
return {}
|
||||
client_info = credentials["client_information"]
|
||||
# Mask sensitive fields from nested structure
|
||||
if client_info.get("client_id"):
|
||||
masked["client_id"] = self._mask_value(client_info["client_id"])
|
||||
if client_info.get("encrypted_client_secret"):
|
||||
masked["client_secret"] = self._mask_value(
|
||||
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
|
||||
)
|
||||
if client_info.get("client_secret"):
|
||||
masked["client_secret"] = self._mask_value(client_info["client_secret"])
|
||||
return masked
|
||||
|
||||
def decrypt_server_url(self) -> str:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
|
|
@ -283,10 +284,12 @@ def refresh_authorization(
|
|||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
try:
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
|
|||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPRefreshTokenError(MCPError):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class MCPToolManageService:
|
|||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
if authentication is not None:
|
||||
if authentication is not None and authentication.client_id and authentication.client_secret:
|
||||
# Build the full credentials structure with encrypted client_id and client_secret
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
authentication.client_id, authentication.client_secret, tenant_id
|
||||
|
|
@ -194,7 +194,7 @@ class MCPToolManageService:
|
|||
mcp_provider.encrypted_headers = None
|
||||
|
||||
# Update credentials if provided
|
||||
if authentication is not None:
|
||||
if authentication is not None and authentication.client_id and authentication.client_secret:
|
||||
# Merge with existing credentials to handle masked values
|
||||
(
|
||||
final_client_id,
|
||||
|
|
@ -305,7 +305,7 @@ class MCPToolManageService:
|
|||
if not authed:
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
|
||||
self._session.flush()
|
||||
self._session.commit()
|
||||
|
||||
def save_oauth_data(self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: str = "mixed") -> None:
|
||||
"""
|
||||
|
|
@ -360,7 +360,7 @@ class MCPToolManageService:
|
|||
return json.dumps({"content": icon, "background": icon_background})
|
||||
return icon
|
||||
|
||||
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> str:
|
||||
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
|
||||
"""Encrypt specified fields in a dictionary.
|
||||
|
||||
Args:
|
||||
|
|
@ -386,12 +386,12 @@ class MCPToolManageService:
|
|||
)
|
||||
|
||||
encrypted_data = encrypter_instance.encrypt(data)
|
||||
return json.dumps(encrypted_data)
|
||||
return encrypted_data
|
||||
|
||||
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||
"""Encrypt headers and prepare for storage."""
|
||||
# All headers are treated as secret
|
||||
return self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id)
|
||||
return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id))
|
||||
|
||||
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
|
||||
"""Prepare headers with OAuth token if available."""
|
||||
|
|
@ -530,11 +530,12 @@ class MCPToolManageService:
|
|||
# Create a flat structure with all credential data
|
||||
credentials_data = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"encrypted_client_secret": client_secret,
|
||||
"client_name": CLIENT_NAME,
|
||||
"is_dynamic_registration": False,
|
||||
}
|
||||
|
||||
# Only client_id and client_secret need encryption
|
||||
secret_fields = ["client_id", "client_secret"] if client_secret else ["client_id"]
|
||||
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
secret_fields = ["encrypted_client_secret"] if client_secret else []
|
||||
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
return json.dumps({"client_information": client_info})
|
||||
|
|
|
|||
Loading…
Reference in New Issue