diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 96bc288a77..d0db941696 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -867,6 +867,12 @@ class ToolProviderMCPApi(Resource): "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 ) parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) + parser.add_argument("client_id", type=str, required=False, nullable=True, location="json", default="") + parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json", default="") + parser.add_argument( + "grant_type", type=str, required=False, nullable=True, location="json", default="authorization_code" + ) + parser.add_argument("scope", type=str, required=False, nullable=True, location="json", default="") args = parser.parse_args() user = current_user if not is_valid_url(args["server_url"]): @@ -885,6 +891,10 @@ class ToolProviderMCPApi(Resource): timeout=args["timeout"], sse_read_timeout=args["sse_read_timeout"], headers=args["headers"], + client_id=args["client_id"], + client_secret=args["client_secret"], + grant_type=args["grant_type"], + scope=args["scope"], ) session.commit() return jsonable_encoder(result) @@ -904,6 +914,10 @@ class ToolProviderMCPApi(Resource): parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") parser.add_argument("headers", type=dict, required=False, nullable=True, location="json") + parser.add_argument("client_id", type=str, required=False, nullable=True, location="json") + parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json") + parser.add_argument("grant_type", type=str, required=False, nullable=True, location="json") + parser.add_argument("scope", type=str, required=False, nullable=True, location="json") args = parser.parse_args() if not is_valid_url(args["server_url"]): if "[__HIDDEN__]" in args["server_url"]: @@ -924,6 +938,10 @@ class ToolProviderMCPApi(Resource): timeout=args.get("timeout"), sse_read_timeout=args.get("sse_read_timeout"), headers=args.get("headers"), + client_id=args.get("client_id"), + client_secret=args.get("client_secret"), + grant_type=args.get("grant_type"), + scope=args.get("scope"), ) session.commit() return {"result": "success"} diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 213a81f1f5..7ab1117769 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -18,6 +18,15 @@ from core.tools.utils.encryption import create_provider_encrypter if TYPE_CHECKING: from models.tools import MCPToolProvider +# Constants +DEFAULT_GRANT_TYPE = "authorization_code" +CLIENT_NAME = "Dify" +CLIENT_URI = "https://github.com/langgenius/dify" +DEFAULT_TOKEN_TYPE = "Bearer" +DEFAULT_EXPIRES_IN = 3600 +MASK_CHAR = "*" +MIN_UNMASK_LENGTH = 6 + class MCPProviderEntity(BaseModel): """MCP Provider domain entity for business logic operations""" @@ -78,13 +87,38 @@ class MCPProviderEntity(BaseModel): @property def client_metadata(self) -> OAuthClientMetadata: """Metadata about this OAuth client.""" + # Get grant type from credentials + credentials = self.decrypt_credentials() + + # Try to get grant_type from different locations + grant_type = credentials.get("grant_type", DEFAULT_GRANT_TYPE) + + # For nested structure, check if client_information has grant_types + if "client_information" in credentials and isinstance(credentials["client_information"], dict): + client_info = credentials["client_information"] + # If grant_types is specified in client_information, use it to determine grant_type + if "grant_types" in client_info and isinstance(client_info["grant_types"], list): + if "client_credentials" in client_info["grant_types"]: + grant_type = "client_credentials" + elif "authorization_code" in client_info["grant_types"]: + grant_type = "authorization_code" + + # Configure based on grant type + is_client_credentials = grant_type == "client_credentials" + + grant_types = ["refresh_token"] + grant_types.append("client_credentials" if is_client_credentials else "authorization_code") + + response_types = [] if is_client_credentials else ["code"] + redirect_uris = [] if is_client_credentials else [self.redirect_url] + return OAuthClientMetadata( - redirect_uris=[self.redirect_url], + redirect_uris=redirect_uris, token_endpoint_auth_method="none", - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - client_name="Dify", - client_uri="https://github.com/langgenius/dify", + grant_types=grant_types, + response_types=response_types, + client_name=CLIENT_NAME, + client_uri=CLIENT_URI, ) @property @@ -100,7 +134,7 @@ class MCPProviderEntity(BaseModel): def to_api_response(self, user_name: str | None = None) -> dict[str, Any]: """Convert to API response format""" - return { + response = { "id": self.id, "author": user_name or "Anonymous", "name": self.name, @@ -117,11 +151,50 @@ class MCPProviderEntity(BaseModel): "description": I18nObject(en_US="", zh_Hans="").to_dict(), } + # Add masked credentials if they exist + masked_creds = self.masked_credentials() + if masked_creds: + response.update(masked_creds) + + return response + def retrieve_client_information(self) -> OAuthClientInformation | None: """OAuth client information if available""" - client_info = self.decrypt_credentials().get("client_information", {}) - if not client_info: + credentials = self.decrypt_credentials() + if not credentials: return None + + # Check if we have nested client_information structure + if "client_information" in credentials: + # Handle nested structure (Authorization Code flow) + client_info_data = credentials["client_information"] + if isinstance(client_info_data, dict): + return OAuthClientInformation.model_validate(client_info_data) + return None + + # Handle flat structure (Client Credentials flow) + if "client_id" not in credentials: + return None + + # Build client information from flat structure + client_info = { + "client_id": credentials.get("client_id", ""), + "client_secret": credentials.get("client_secret", ""), + "client_name": credentials.get("client_name", CLIENT_NAME), + } + + # Parse JSON fields if they exist + json_fields = ["redirect_uris", "grant_types", "response_types"] + for field in json_fields: + if field in credentials: + try: + client_info[field] = json.loads(credentials[field]) + except: + client_info[field] = [] + + if "scope" in credentials: + client_info["scope"] = credentials["scope"] + return OAuthClientInformation.model_validate(client_info) def retrieve_tokens(self) -> OAuthTokens | None: @@ -131,8 +204,8 @@ class MCPProviderEntity(BaseModel): credentials = self.decrypt_credentials() return OAuthTokens( access_token=credentials.get("access_token", ""), - token_type=credentials.get("token_type", "Bearer"), - expires_in=int(credentials.get("expires_in", "3600") or 3600), + token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE), + expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN), refresh_token=credentials.get("refresh_token", ""), ) @@ -144,30 +217,77 @@ class MCPProviderEntity(BaseModel): return f"{base_url}/******" return base_url + def _mask_value(self, value: str) -> str: + """Mask a sensitive value for display""" + if len(value) > MIN_UNMASK_LENGTH: + return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:] + else: + return MASK_CHAR * len(value) + def masked_headers(self) -> dict[str, str]: """Masked headers for display""" - masked: dict[str, str] = {} - for key, value in self.decrypt_headers().items(): - if len(value) > 6: - masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:] - else: - masked[key] = "*" * len(value) + return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()} + + def masked_credentials(self) -> dict[str, str]: + """Masked credentials for display""" + credentials = self.decrypt_credentials() + if not credentials: + return {} + + masked = {} + + # Check if we have nested client_information structure + if "client_information" in credentials and isinstance(credentials["client_information"], dict): + client_info = credentials["client_information"] + # Mask sensitive fields from nested structure + if client_info.get("client_id"): + masked["client_id"] = self._mask_value(client_info["client_id"]) + if client_info.get("client_secret"): + masked["client_secret"] = self._mask_value(client_info["client_secret"]) + else: + # Handle flat structure + # Mask sensitive fields + sensitive_fields = ["client_id", "client_secret"] + for field in sensitive_fields: + if credentials.get(field): + masked[field] = self._mask_value(credentials[field]) + + # Include non-sensitive fields (check both flat and nested structures) + if "grant_type" in credentials: + masked["grant_type"] = credentials["grant_type"] + if "scope" in credentials: + masked["scope"] = credentials["scope"] + return masked def decrypt_server_url(self) -> str: """Decrypt server URL""" - return encrypter.decrypt_token(self.tenant_id, self.server_url) - def decrypt_headers(self) -> dict[str, Any]: - """Decrypt headers""" - + def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]: + """Generic method to decrypt dictionary fields""" try: - if not self.headers: + if not data: return {} - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers] + # Only decrypt fields that are actually encrypted + # For nested structures, client_information is not encrypted as a whole + encrypted_fields = [] + for key, value in data.items(): + # Skip nested objects - they are not encrypted + if isinstance(value, dict): + continue + # Only process string values that might be encrypted + if isinstance(value, str) and value: + encrypted_fields.append(key) + + if not encrypted_fields: + return data + + # Create dynamic config only for encrypted fields + config = [ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields + ] encrypter_instance, _ = create_provider_encrypter( tenant_id=self.tenant_id, @@ -175,28 +295,21 @@ class MCPProviderEntity(BaseModel): cache=NoOpProviderCredentialCache(), ) - result = encrypter_instance.decrypt(self.headers) + # Decrypt only the encrypted fields + decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields}) + + # Merge decrypted data with original data (preserving non-encrypted fields) + result = data.copy() + result.update(decrypted_data) + return result except Exception: return {} - def decrypt_credentials( - self, - ) -> dict[str, Any]: + def decrypt_headers(self) -> dict[str, Any]: + """Decrypt headers""" + return self._decrypt_dict(self.headers) + + def decrypt_credentials(self) -> dict[str, Any]: """Decrypt credentials""" - try: - if not self.credentials: - return {} - - encrypter, _ = create_provider_encrypter( - tenant_id=self.tenant_id, - config=[ - BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) - for key in self.credentials - ], - cache=NoOpProviderCredentialCache(), - ) - - return encrypter.decrypt(self.credentials) - except Exception: - return {} + return self._decrypt_dict(self.credentials) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index a6bb4fdda9..5f3b20121a 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -106,8 +106,8 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: """Check if the server supports OAuth 2.0 Resource Discovery.""" - b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True) - url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" + b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True) + url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource" if b_query: url_for_resource_discovery += f"?{b_query}" if b_fragment: @@ -117,7 +117,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: response = httpx.get(url_for_resource_discovery, headers=headers) if 200 <= response.status_code < 300: body = response.json() - if "authorization_server_url" in body: + # Support both singular and plural forms + if body.get("authorization_servers"): + return True, body["authorization_servers"][0] + elif body.get("authorization_server_url"): return True, body["authorization_server_url"][0] else: return False, "" @@ -132,27 +135,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None # First check if the server supports OAuth 2.0 Resource Discovery support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) if support_resource_discovery: - url = oauth_discovery_url + # The oauth_discovery_url is the authorization server base URL + # Try OpenID Connect discovery first (more common), then OAuth 2.0 + urls_to_try = [ + urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"), + urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"), + ] else: - url = urljoin(server_url, "/.well-known/oauth-authorization-server") + urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")] - try: - headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} - response = httpx.get(url, headers=headers) - if response.status_code == 404: - return None - if not response.is_success: - raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") - return OAuthMetadata.model_validate(response.json()) - except httpx.RequestError as e: - if isinstance(e, httpx.ConnectError): - response = httpx.get(url) + headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} + + for url in urls_to_try: + try: + response = httpx.get(url, headers=headers) if response.status_code == 404: - return None + continue # Try next URL if not response.is_success: raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") return OAuthMetadata.model_validate(response.json()) - raise + except httpx.RequestError as e: + if isinstance(e, httpx.ConnectError): + response = httpx.get(url) + if response.status_code == 404: + continue # Try next URL + if not response.is_success: + raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + return OAuthMetadata.model_validate(response.json()) + # For other errors, try next URL + continue + + return None # No metadata found def start_authorization( @@ -276,6 +289,49 @@ def refresh_authorization( return OAuthTokens.model_validate(response.json()) +def client_credentials_flow( + server_url: str, + metadata: OAuthMetadata | None, + client_information: OAuthClientInformation, + scope: str | None = None, +) -> OAuthTokens: + """Execute Client Credentials Flow to get access token.""" + grant_type = "client_credentials" + + if metadata: + token_url = metadata.token_endpoint + if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: + raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") + else: + token_url = urljoin(server_url, "/token") + + # Support both Basic Auth and body parameters for client authentication + headers = {"Content-Type": "application/x-www-form-urlencoded"} + data = {"grant_type": grant_type} + + if scope: + data["scope"] = scope + + # If client_secret is provided, use Basic Auth (preferred method) + if client_information.client_secret: + credentials = f"{client_information.client_id}:{client_information.client_secret}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded_credentials}" + else: + # Fall back to including credentials in the body + data["client_id"] = client_information.client_id + if client_information.client_secret: + data["client_secret"] = client_information.client_secret + + response = httpx.post(token_url, headers=headers, data=data) + if not response.is_success: + raise ValueError( + f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}" + ) + + return OAuthTokens.model_validate(response.json()) + + def register_client( server_url: str, metadata: OAuthMetadata | None, @@ -304,6 +360,7 @@ def auth( mcp_service: "MCPToolManageService", authorization_code: str | None = None, state_param: str | None = None, + grant_type: str = "authorization_code", ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" server_url = provider.decrypt_server_url() @@ -314,9 +371,22 @@ def auth( client_information = provider.retrieve_client_information() redirect_url = provider.redirect_url + # Check if we should use client credentials flow + credentials = provider.decrypt_credentials() + stored_grant_type = credentials.get("grant_type", "authorization_code") + + # Use stored grant type if available, otherwise use parameter + effective_grant_type = stored_grant_type or grant_type + if not client_information: if authorization_code is not None: raise ValueError("Existing OAuth client information is required when exchanging an authorization code") + + # For client credentials flow, we don't need to register client dynamically + if effective_grant_type == "client_credentials": + # Client should provide client_id and client_secret directly + raise ValueError("Client credentials flow requires client_id and client_secret to be provided") + try: full_information = register_client(server_url, server_metadata, client_metadata) except httpx.RequestError as e: @@ -329,7 +399,28 @@ def auth( client_information = full_information - # Exchange authorization code for tokens + # Handle client credentials flow + if effective_grant_type == "client_credentials": + # Direct token request without user interaction + try: + scope = credentials.get("scope") + tokens = client_credentials_flow( + server_url, + server_metadata, + client_information, + scope, + ) + + # Save tokens and grant type + token_data = tokens.model_dump() + token_data["grant_type"] = "client_credentials" + mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens") + + return {"result": "success"} + except Exception as e: + raise ValueError(f"Client credentials flow failed: {e}") + + # Exchange authorization code for tokens (Authorization Code flow) if authorization_code is not None: if not state_param: raise ValueError("State parameter is required when exchanging authorization code") @@ -377,7 +468,7 @@ def auth( except Exception as e: raise ValueError(f"Could not refresh OAuth tokens: {e}") - # Start new authorization flow + # Start new authorization flow (only for authorization code flow) authorization_url, code_verifier = start_authorization( server_url, server_metadata, diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 1eacd198cb..c6b4368059 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -47,6 +47,11 @@ class ToolProviderApiEntity(BaseModel): sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") + # MCP OAuth credentials + client_id: str | None = Field(default=None, description="The masked client ID for OAuth") + client_secret: str | None = Field(default=None, description="The masked client secret for OAuth") + grant_type: str | None = Field(default=None, description="The OAuth grant type") + scope: str | None = Field(default=None, description="The OAuth scope") @field_validator("tools", mode="before") @classmethod @@ -72,6 +77,10 @@ class ToolProviderApiEntity(BaseModel): optional_fields.update(self.optional_field("timeout", self.timeout)) optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout)) optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) + optional_fields.update(self.optional_field("client_id", self.client_id)) + optional_fields.update(self.optional_field("client_secret", self.client_secret)) + optional_fields.update(self.optional_field("grant_type", self.grant_type)) + optional_fields.update(self.optional_field("scope", self.scope)) return { "id": self.id, "author": self.author, diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index a886ba5614..1e491ffd62 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -9,6 +9,7 @@ from sqlalchemy import or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session +from configs import dify_config from core.entities.mcp_provider import MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache @@ -21,7 +22,12 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) +# Constants UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" +DEFAULT_GRANT_TYPE = "authorization_code" +CLIENT_NAME = "Dify" +EMPTY_TOOLS_JSON = "[]" +EMPTY_CREDENTIALS_JSON = "{}" class MCPToolManageService: @@ -85,6 +91,10 @@ class MCPToolManageService: timeout: float, sse_read_timeout: float, headers: dict[str, str] | None = None, + client_id: str | None = None, + client_secret: str | None = None, + grant_type: str = DEFAULT_GRANT_TYPE, + scope: str | None = None, ) -> ToolProviderApiEntity: """Create a new MCP provider.""" server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() @@ -94,8 +104,14 @@ class MCPToolManageService: # Encrypt sensitive data encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None - + encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None + if client_id and client_secret: + # Build the full credentials structure with encrypted client_id and client_secret + encrypted_credentials = self._build_and_encrypt_credentials( + client_id, client_secret, grant_type, scope, tenant_id + ) + else: + encrypted_credentials = None # Create provider mcp_tool = MCPToolProvider( tenant_id=tenant_id, @@ -104,12 +120,13 @@ class MCPToolManageService: server_url_hash=server_url_hash, user_id=user_id, authed=False, - tools="[]", + tools=EMPTY_TOOLS_JSON, icon=self._prepare_icon(icon, icon_type, icon_background), server_identifier=server_identifier, timeout=timeout, sse_read_timeout=sse_read_timeout, encrypted_headers=encrypted_headers, + encrypted_credentials=encrypted_credentials, ) self._session.add(mcp_tool) @@ -131,6 +148,10 @@ class MCPToolManageService: timeout: float | None = None, sse_read_timeout: float | None = None, headers: dict[str, str] | None = None, + client_id: str | None = None, + client_secret: str | None = None, + grant_type: str | None = None, + scope: str | None = None, ) -> None: """Update an MCP provider.""" mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) @@ -176,11 +197,31 @@ class MCPToolManageService: if headers: # Build headers preserving unchanged masked values final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) - encrypted_headers_dict = self._prepare_encrypted_headers(final_headers, tenant_id) + encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id) mcp_provider.encrypted_headers = encrypted_headers_dict else: # Clear headers if empty dict passed mcp_provider.encrypted_headers = None + + # Update credentials if provided + if client_id is not None and client_secret is not None: + # Merge with existing credentials to handle masked values + ( + final_client_id, + final_client_secret, + final_grant_type, + final_scope, + ) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider) + + # Use default grant_type if none found + final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE + + # Build and encrypt new credentials + encrypted_credentials = self._build_and_encrypt_credentials( + final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id + ) + mcp_provider.encrypted_credentials = encrypted_credentials + self._session.commit() except IntegrityError as e: self._session.rollback() @@ -271,7 +312,7 @@ class MCPToolManageService: if authed is not None: provider.authed = authed if not authed: - provider.tools = "[]" + provider.tools = EMPTY_TOOLS_JSON self._session.commit() @@ -287,28 +328,15 @@ class MCPToolManageService: """ db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - credentials = {} - authed = None + # Determine if this makes the provider authenticated + authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None - if data_type == "tokens" or (data_type == "mixed" and "access_token" in data): - # OAuth tokens - credentials = data - authed = True - elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data): - # OAuth client information - credentials = data - elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data): - # PKCE code verifier - credentials = data - else: - credentials = data - - self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed) + self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed) def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None: """Clear all credentials for a provider.""" - provider.tools = "[]" - provider.encrypted_credentials = "{}" + provider.tools = EMPTY_TOOLS_JSON + provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON provider.updated_at = datetime.now() provider.authed = False self._session.commit() @@ -341,13 +369,24 @@ class MCPToolManageService: return json.dumps({"content": icon, "background": icon_background}) return icon - def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str: - """Encrypt headers and prepare for storage.""" + def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> str: + """Encrypt specified fields in a dictionary. + + Args: + data: Dictionary containing data to encrypt + secret_fields: List of field names to encrypt + tenant_id: Tenant ID for encryption + + Returns: + JSON string of encrypted data + """ from core.entities.provider_entities import BasicProviderConfig from core.tools.utils.encryption import create_provider_encrypter - # Create dynamic config for all headers as SECRET_INPUT - config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers] + # Create config for secret fields + config = [ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields + ] encrypter_instance, _ = create_provider_encrypter( tenant_id=tenant_id, @@ -355,8 +394,13 @@ class MCPToolManageService: cache=NoOpProviderCredentialCache(), ) - encrypted_headers_dict = encrypter_instance.encrypt(headers) - return json.dumps(encrypted_headers_dict) + encrypted_data = encrypter_instance.encrypt(data) + return json.dumps(encrypted_data) + + def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str: + """Encrypt headers and prepare for storage.""" + # All headers are treated as secret + return self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id) def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]: """Prepare headers with OAuth token if available.""" @@ -391,27 +435,18 @@ class MCPToolManageService: provider_entity = provider.to_entity() headers = provider_entity.headers - timeout = provider_entity.timeout - sse_read_timeout = provider_entity.sse_read_timeout try: - with MCPClientWithAuthRetry( - server_url, - headers=headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - provider_entity=provider_entity, - auth_callback=lambda p, s, c: auth(p, self, c), - mcp_service=self, - ) as mcp_client: - tools = mcp_client.list_tools() - return { - "authed": True, - "tools": json.dumps([tool.model_dump() for tool in tools]), - "encrypted_credentials": "{}", - } + tools = self._retrieve_remote_mcp_tools( + server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c) + ) + return { + "authed": True, + "tools": json.dumps([tool.model_dump() for tool in tools]), + "encrypted_credentials": EMPTY_CREDENTIALS_JSON, + } except MCPAuthError: - return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"} + return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON} except MCPError as e: raise ValueError(f"Failed to re-connect MCP server: {e}") from e @@ -461,3 +496,76 @@ class MCPToolManageService: for key, value in incoming_headers.items() if key in existing_decrypted or value != existing_masked.get(key) } + + def _merge_credentials_with_masked( + self, + client_id: str, + client_secret: str, + grant_type: str | None, + scope: str | None, + mcp_provider: MCPToolProvider, + ) -> tuple[str, str, str | None, str | None]: + """Merge incoming credentials with existing ones, preserving unchanged masked values. + + Args: + client_id: Client ID from frontend (may be masked) + client_secret: Client secret from frontend (may be masked) + grant_type: Grant type from frontend + scope: OAuth scope from frontend + mcp_provider: The MCP provider instance + + Returns: + Tuple of (final_client_id, final_client_secret, grant_type, scope) + """ + mcp_provider_entity = mcp_provider.to_entity() + existing_decrypted = mcp_provider_entity.decrypt_credentials() + existing_masked = mcp_provider_entity.masked_credentials() + + # Check if client_id is masked and unchanged + final_client_id = client_id + if existing_masked.get("client_id") and client_id == existing_masked["client_id"]: + # Use existing decrypted value + final_client_id = existing_decrypted.get("client_id", client_id) + + # Check if client_secret is masked and unchanged + final_client_secret = client_secret + if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]: + # Use existing decrypted value + final_client_secret = existing_decrypted.get("client_secret", client_secret) + + # Grant type and scope are not masked, use as is + final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type") + final_scope = scope if scope is not None else existing_decrypted.get("scope") + + return final_client_id, final_client_secret, final_grant_type, final_scope + + def _build_and_encrypt_credentials( + self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str + ) -> str: + """Build credentials and encrypt sensitive fields.""" + # Create a flat structure with all credential data + credentials_data = { + "client_id": client_id, + "client_secret": client_secret, + "grant_type": grant_type, + "client_name": CLIENT_NAME, + } + + if scope: + credentials_data["scope"] = scope + + # Add grant types and response types based on grant_type + if grant_type == "client_credentials": + credentials_data["grant_types"] = json.dumps(["client_credentials"]) + credentials_data["response_types"] = json.dumps([]) + credentials_data["redirect_uris"] = json.dumps([]) + else: + credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"]) + credentials_data["response_types"] = json.dumps(["code"]) + credentials_data["redirect_uris"] = json.dumps( + [f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"] + ) + + # Only client_id and client_secret need encryption + secret_fields = ["client_id", "client_secret"] + return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id) diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx index 211d594caf..2bbcdc97b9 100644 --- a/web/app/components/tools/mcp/modal.tsx +++ b/web/app/components/tools/mcp/modal.tsx @@ -2,13 +2,14 @@ import React, { useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { getDomain } from 'tldts' -import { RiCloseLine, RiEditLine } from '@remixicon/react' +import { RiArrowDownSLine, RiCloseLine, RiEditLine } from '@remixicon/react' import AppIconPicker from '@/app/components/base/app-icon-picker' import type { AppIconSelection } from '@/app/components/base/app-icon-picker' import AppIcon from '@/app/components/base/app-icon' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' +import Select from '@/app/components/base/select' import HeadersInput from './headers-input' import type { AppIconType } from '@/types/app' import type { ToolWithProvider } from '@/app/components/workflow/types' @@ -31,6 +32,10 @@ export type DuplicateAppModalProps = { timeout: number sse_read_timeout: number headers?: Record + client_id?: string + client_secret?: string + grant_type?: string + scope?: string }) => void onHide: () => void } @@ -73,6 +78,12 @@ const MCPModal = ({ const [headers, setHeaders] = React.useState>( data?.masked_headers || {}, ) + const [clientId, setClientId] = React.useState(data?.client_id || '') + const [clientSecret, setClientSecret] = React.useState(data?.client_secret || '') + const [grantType, setGrantType] = React.useState(data?.grant_type || 'authorization_code') + const [scope, setScope] = React.useState(data?.scope || '') + const [authCollapsed, setAuthCollapsed] = React.useState(true) + const [configCollapsed, setConfigCollapsed] = React.useState(true) const [isFetchingIcon, setIsFetchingIcon] = useState(false) const appIconRef = useRef(null) const isHovering = useHover(appIconRef) @@ -86,6 +97,10 @@ const MCPModal = ({ setMcpTimeout(data.timeout || 30) setSseReadTimeout(data.sse_read_timeout || 300) setHeaders(data.masked_headers || {}) + setClientId(data.client_id || '') + setClientSecret(data.client_secret || '') + setGrantType(data.grant_type || 'authorization_code') + setScope(data.scope || '') setAppIcon(getIcon(data)) } else { @@ -96,6 +111,10 @@ const MCPModal = ({ setMcpTimeout(30) setSseReadTimeout(300) setHeaders({}) + setClientId('') + setClientSecret('') + setGrantType('authorization_code') + setScope('') setAppIcon(DEFAULT_ICON as AppIconSelection) } }, [data]) @@ -124,7 +143,8 @@ const MCPModal = ({ setIsFetchingIcon(true) try { const res = await uploadRemoteFileInfo(remoteIcon, undefined, true) - setAppIcon({ type: 'image', url: res.url, fileId: extractFileId(res.url) || '' }) + if ('url' in res) + setAppIcon({ type: 'image', url: res.url, fileId: extractFileId(res.url) || '' }) } catch (e) { let errorMessage = 'Failed to fetch remote icon' @@ -158,6 +178,10 @@ const MCPModal = ({ timeout: timeout || 30, sse_read_timeout: sseReadTimeout || 300, headers: Object.keys(headers).length > 0 ? headers : undefined, + client_id: clientId || undefined, + client_secret: clientSecret || undefined, + grant_type: grantType, + scope: scope || undefined, }) if(isCreate) onHide() @@ -236,41 +260,116 @@ const MCPModal = ({ )} +
-
- {t('tools.mcp.modal.timeout')} +
setAuthCollapsed(!authCollapsed)} + > + {t('tools.mcp.modal.authentication')} +
- setMcpTimeout(Number(e.target.value))} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.timeoutPlaceholder')} - /> + {!authCollapsed && ( +
+
+
+ {t('tools.mcp.modal.grantType')} +
+ setClientId(e.target.value)} + placeholder={t('tools.mcp.modal.clientIdPlaceholder')} + /> +
+
+
+ {t('tools.mcp.modal.clientSecret')} +
+ setClientSecret(e.target.value)} + placeholder={t('tools.mcp.modal.clientSecretPlaceholder')} + /> +
+ {grantType === 'client_credentials' && ( +
+
+ {t('tools.mcp.modal.scope')} +
+ setScope(e.target.value)} + placeholder={t('tools.mcp.modal.scopePlaceholder')} + /> +
+ )} +
+
+ {t('tools.mcp.modal.headers')} +
+
{t('tools.mcp.modal.headersTip')}
+ 0} + /> +
+
+ )}
+
-
- {t('tools.mcp.modal.sseReadTimeout')} +
setConfigCollapsed(!configCollapsed)} + > + {t('tools.mcp.modal.configuration')} +
- setSseReadTimeout(Number(e.target.value))} - onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.timeoutPlaceholder')} - /> -
-
-
- {t('tools.mcp.modal.headers')} -
-
{t('tools.mcp.modal.headersTip')}
- 0} - /> + {!configCollapsed && ( +
+
+
+ {t('tools.mcp.modal.timeout')} +
+ setMcpTimeout(Number(e.target.value))} + onBlur={e => handleBlur(e.target.value.trim())} + placeholder={t('tools.mcp.modal.timeoutPlaceholder')} + /> +
+
+
+ {t('tools.mcp.modal.sseReadTimeout')} +
+ setSseReadTimeout(Number(e.target.value))} + onBlur={e => handleBlur(e.target.value.trim())} + placeholder={t('tools.mcp.modal.timeoutPlaceholder')} + /> +
+
+ )}
diff --git a/web/app/components/tools/types.ts b/web/app/components/tools/types.ts index 5a5c2e0400..7dce0ceeda 100644 --- a/web/app/components/tools/types.ts +++ b/web/app/components/tools/types.ts @@ -61,6 +61,10 @@ export type Collection = { sse_read_timeout?: number headers?: Record masked_headers?: Record + client_id?: string + client_secret?: string + grant_type?: string + scope?: string } export type ToolParameter = { diff --git a/web/i18n/en-US/tools.ts b/web/i18n/en-US/tools.ts index 97c557e62d..468c2bdf6c 100644 --- a/web/i18n/en-US/tools.ts +++ b/web/i18n/en-US/tools.ts @@ -203,6 +203,17 @@ const translation = { timeout: 'Timeout', sseReadTimeout: 'SSE Read Timeout', timeoutPlaceholder: '30', + configuration: 'Configuration', + authentication: 'Authentication', + grantType: 'Grant Type', + grantTypeAuthCode: 'Authorization Code (User Authentication)', + grantTypeClientCredentials: 'Client Credentials (Service-to-Service)', + scope: 'OAuth Scope', + scopePlaceholder: 'Enter OAuth scope (optional)', + clientId: 'Client ID', + clientIdPlaceholder: 'Enter client ID', + clientSecret: 'Client Secret', + clientSecretPlaceholder: 'Enter client secret', }, delete: 'Remove MCP Server', deleteConfirmTitle: 'Would you like to remove {{mcp}}?', diff --git a/web/i18n/zh-Hans/tools.ts b/web/i18n/zh-Hans/tools.ts index 9ade1caaad..b9c3a851ca 100644 --- a/web/i18n/zh-Hans/tools.ts +++ b/web/i18n/zh-Hans/tools.ts @@ -203,6 +203,12 @@ const translation = { timeout: '超时时间', sseReadTimeout: 'SSE 读取超时时间', timeoutPlaceholder: '30', + configuration: '配置', + authentication: '认证', + clientId: '客户端 ID', + clientIdPlaceholder: '请输入客户端 ID', + clientSecret: '客户端密钥', + clientSecretPlaceholder: '请输入客户端密钥', }, delete: '删除 MCP 服务', deleteConfirmTitle: '你想要删除 {{mcp}} 吗?', diff --git a/web/service/use-tools.ts b/web/service/use-tools.ts index 4bd265bf51..b0563cd7ac 100644 --- a/web/service/use-tools.ts +++ b/web/service/use-tools.ts @@ -88,6 +88,10 @@ export const useCreateMCP = () => { timeout?: number sse_read_timeout?: number headers?: Record + client_id?: string + client_secret?: string + grant_type?: string + scope?: string }) => { return post('workspaces/current/tool-provider/mcp', { body: { @@ -115,6 +119,10 @@ export const useUpdateMCP = ({ timeout?: number sse_read_timeout?: number headers?: Record + client_id?: string + client_secret?: string + grant_type?: string + scope?: string }) => { return put('workspaces/current/tool-provider/mcp', { body: {