diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 68d17b5ada..34474ef506 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -17,6 +17,7 @@ from controllers.console.wraps import ( enterprise_license_required, setup_required, ) +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient @@ -895,39 +896,37 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) - parser.add_argument( - "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 - ) + parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - parser.add_argument("client_id", type=str, required=False, nullable=True, location="json", default="") - parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json", default="") - parser.add_argument( - "grant_type", type=str, required=False, nullable=True, location="json", default="authorization_code" - ) - parser.add_argument("scope", type=str, required=False, nullable=True, location="json", default="") + parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) args = parser.parse_args() - user = current_user + + # Validate server URL if not is_valid_url(args["server_url"]): raise ValueError("Server URL is not valid.") + + # Parse and validate models + configuration = MCPConfiguration.model_validate(args["configuration"]) + authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + + # Create provider with Session(db.engine) as session: service = MCPToolManageService(session=session) result = service.create_provider( - tenant_id=user.current_tenant_id, + tenant_id=current_user.current_tenant_id, + user_id=current_user.id, server_url=args["server_url"], name=args["name"], icon=args["icon"], icon_type=args["icon_type"], icon_background=args["icon_background"], - user_id=user.id, server_identifier=args["server_identifier"], - timeout=args["timeout"], - sse_read_timeout=args["sse_read_timeout"], + timeout=configuration.timeout, + sse_read_timeout=configuration.sse_read_timeout, headers=args["headers"], - client_id=args["client_id"], - client_secret=args["client_secret"], - grant_type=args["grant_type"], - scope=args["scope"], + client_id=authentication.client_id if authentication else None, + client_secret=authentication.client_secret if authentication else None, + grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE, ) session.commit() return jsonable_encoder(result) @@ -944,19 +943,17 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") - parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") - parser.add_argument("headers", type=dict, required=False, nullable=True, location="json") - parser.add_argument("client_id", type=str, required=False, nullable=True, location="json") - parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json") - parser.add_argument("grant_type", type=str, required=False, nullable=True, location="json") - parser.add_argument("scope", type=str, required=False, nullable=True, location="json") + parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json") + parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json") + args = parser.parse_args() if not is_valid_url(args["server_url"]): if "[__HIDDEN__]" in args["server_url"]: pass else: raise ValueError("Server URL is not valid.") + configuration = MCPConfiguration.model_validate(args["configuration"]) + authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None with Session(db.engine) as session: service = MCPToolManageService(session=session) service.update_provider( @@ -968,13 +965,12 @@ class ToolProviderMCPApi(Resource): icon_type=args["icon_type"], icon_background=args["icon_background"], server_identifier=args["server_identifier"], - timeout=args.get("timeout"), - sse_read_timeout=args.get("sse_read_timeout"), - headers=args.get("headers"), - client_id=args.get("client_id"), - client_secret=args.get("client_secret"), - grant_type=args.get("grant_type"), - scope=args.get("scope"), + timeout=configuration.timeout, + sse_read_timeout=configuration.sse_read_timeout, + headers=args["headers"], + client_id=authentication.client_id if authentication else None, + client_secret=authentication.client_secret if authentication else None, + grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE, ) session.commit() return {"result": "success"} diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7ab1117769..ef95ed9b4c 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -1,9 +1,10 @@ import json from datetime import datetime +from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from pydantic import BaseModel +from pydantic import BaseModel, Field from configs import dify_config from core.entities.provider_entities import BasicProviderConfig @@ -28,6 +29,24 @@ MASK_CHAR = "*" MIN_UNMASK_LENGTH = 6 +class MCPSupportGrantType(StrEnum): + """The supported grant types for MCP""" + + AUTHORIZATION_CODE = "authorization_code" + CLIENT_CREDENTIALS = "client_credentials" + + +class MCPAuthentication(BaseModel): + client_id: str + client_secret: str | None = None + grant_type: MCPSupportGrantType = Field(default=MCPSupportGrantType.AUTHORIZATION_CODE) + + +class MCPConfiguration(BaseModel): + timeout: float = 30 + sse_read_timeout: float = 300 + + class MCPProviderEntity(BaseModel): """MCP Provider domain entity for business logic operations""" @@ -143,19 +162,25 @@ class MCPProviderEntity(BaseModel): "is_team_authorization": self.authed, "server_url": self.masked_server_url(), "server_identifier": self.provider_id, - "timeout": self.timeout, - "sse_read_timeout": self.sse_read_timeout, - "masked_headers": self.masked_headers(), "updated_at": int(self.updated_at.timestamp()), "label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(), "description": I18nObject(en_US="", zh_Hans="").to_dict(), } - # Add masked credentials if they exist + # Add configuration + response["configuration"] = { + "timeout": str(self.timeout), + "sse_read_timeout": str(self.sse_read_timeout), + } + + # Add masked headers + response["masked_headers"] = self.masked_headers() + + # Add authentication info if available masked_creds = self.masked_credentials() if masked_creds: - response.update(masked_creds) - + response["authentication"] = masked_creds + response["is_dynamic_registration"] = self.credentials.get("is_dynamic_registration", True) return response def retrieve_client_information(self) -> OAuthClientInformation | None: @@ -255,8 +280,6 @@ class MCPProviderEntity(BaseModel): # Include non-sensitive fields (check both flat and nested structures) if "grant_type" in credentials: masked["grant_type"] = credentials["grant_type"] - if "scope" in credentials: - masked["scope"] = credentials["scope"] return masked diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 5bae6b71e3..c3cb19a312 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -44,15 +44,14 @@ class ToolProviderApiEntity(BaseModel): server_url: str | None = Field(default="", description="The server url of the tool") updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool") - timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool") - sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") + masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") - # MCP OAuth credentials - client_id: str | None = Field(default=None, description="The masked client ID for OAuth") - client_secret: str | None = Field(default=None, description="The masked client secret for OAuth") - grant_type: str | None = Field(default=None, description="The OAuth grant type") - scope: str | None = Field(default=None, description="The OAuth scope") + authentication: dict[str, str] | None = Field(default=None, description="The OAuth config of the MCP tool") + is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered") + configuration: dict[str, str] | None = Field( + default=None, description="The timeout and sse_read_timeout of the MCP tool" + ) @field_validator("tools", mode="before") @classmethod @@ -75,13 +74,11 @@ class ToolProviderApiEntity(BaseModel): if self.type == ToolProviderType.MCP: optional_fields.update(self.optional_field("updated_at", self.updated_at)) optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) - optional_fields.update(self.optional_field("timeout", self.timeout)) - optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout)) + optional_fields.update(self.optional_field("configuration", self.configuration)) + optional_fields.update(self.optional_field("authentication", self.authentication)) + optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) - optional_fields.update(self.optional_field("client_id", self.client_id)) - optional_fields.update(self.optional_field("client_secret", self.client_secret)) - optional_fields.update(self.optional_field("grant_type", self.grant_type)) - optional_fields.update(self.optional_field("scope", self.scope)) + optional_fields.update(self.optional_field("original_headers", self.original_headers)) return { "id": self.id, "author": self.author, diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index f4ea5424f8..0529b84fdb 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -94,7 +94,6 @@ class MCPToolManageService: client_id: str | None = None, client_secret: str | None = None, grant_type: str = DEFAULT_GRANT_TYPE, - scope: str | None = None, ) -> ToolProviderApiEntity: """Create a new MCP provider.""" server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() @@ -107,9 +106,7 @@ class MCPToolManageService: encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None if client_id and client_secret: # Build the full credentials structure with encrypted client_id and client_secret - encrypted_credentials = self._build_and_encrypt_credentials( - client_id, client_secret, grant_type, scope, tenant_id - ) + encrypted_credentials = self._build_and_encrypt_credentials(client_id, client_secret, grant_type, tenant_id) else: encrypted_credentials = None # Create provider @@ -151,7 +148,6 @@ class MCPToolManageService: client_id: str | None = None, client_secret: str | None = None, grant_type: str | None = None, - scope: str | None = None, ) -> None: """Update an MCP provider.""" mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) @@ -210,15 +206,14 @@ class MCPToolManageService: final_client_id, final_client_secret, final_grant_type, - final_scope, - ) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider) + ) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, mcp_provider) # Use default grant_type if none found final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE # Build and encrypt new credentials encrypted_credentials = self._build_and_encrypt_credentials( - final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id + final_client_id, final_client_secret, final_grant_type, tenant_id ) mcp_provider.encrypted_credentials = encrypted_credentials @@ -502,20 +497,18 @@ class MCPToolManageService: client_id: str, client_secret: str, grant_type: str | None, - scope: str | None, mcp_provider: MCPToolProvider, - ) -> tuple[str, str, str | None, str | None]: + ) -> tuple[str, str, str | None]: """Merge incoming credentials with existing ones, preserving unchanged masked values. Args: client_id: Client ID from frontend (may be masked) client_secret: Client secret from frontend (may be masked) grant_type: Grant type from frontend - scope: OAuth scope from frontend mcp_provider: The MCP provider instance Returns: - Tuple of (final_client_id, final_client_secret, grant_type, scope) + Tuple of (final_client_id, final_client_secret, grant_type) """ mcp_provider_entity = mcp_provider.to_entity() existing_decrypted = mcp_provider_entity.decrypt_credentials() @@ -533,14 +526,12 @@ class MCPToolManageService: # Use existing decrypted value final_client_secret = existing_decrypted.get("client_secret", client_secret) - # Grant type and scope are not masked, use as is final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type") - final_scope = scope if scope is not None else existing_decrypted.get("scope") - return final_client_id, final_client_secret, final_grant_type, final_scope + return final_client_id, final_client_secret, final_grant_type def _build_and_encrypt_credentials( - self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str + self, client_id: str, client_secret: str, grant_type: str, tenant_id: str ) -> str: """Build credentials and encrypt sensitive fields.""" # Create a flat structure with all credential data @@ -549,11 +540,9 @@ class MCPToolManageService: "client_secret": client_secret, "grant_type": grant_type, "client_name": CLIENT_NAME, + "is_dynamic_registration": False, } - if scope: - credentials_data["scope"] = scope - # Add grant types and response types based on grant_type if grant_type == "client_credentials": credentials_data["grant_types"] = json.dumps(["client_credentials"])