chore: change the field

This commit is contained in:
Novice 2025-10-13 18:06:23 +08:00
parent a538f80e95
commit 0a6da0bf2f
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
4 changed files with 79 additions and 74 deletions

View File

@ -17,6 +17,7 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
@ -895,39 +896,37 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("client_id", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument(
"grant_type", type=str, required=False, nullable=True, location="json", default="authorization_code"
)
parser.add_argument("scope", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
# Validate server URL
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=user.current_tenant_id,
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
headers=args["headers"],
client_id=args["client_id"],
client_secret=args["client_secret"],
grant_type=args["grant_type"],
scope=args["scope"],
client_id=authentication.client_id if authentication else None,
client_secret=authentication.client_secret if authentication else None,
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
)
session.commit()
return jsonable_encoder(result)
@ -944,19 +943,17 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
parser.add_argument("client_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json")
parser.add_argument("grant_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("scope", type=str, required=False, nullable=True, location="json")
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json")
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
service.update_provider(
@ -968,13 +965,12 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
client_id=args.get("client_id"),
client_secret=args.get("client_secret"),
grant_type=args.get("grant_type"),
scope=args.get("scope"),
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
headers=args["headers"],
client_id=authentication.client_id if authentication else None,
client_secret=authentication.client_secret if authentication else None,
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
)
session.commit()
return {"result": "success"}

View File

@ -1,9 +1,10 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
from pydantic import BaseModel, Field
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
@ -28,6 +29,24 @@ MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPSupportGrantType(StrEnum):
"""The supported grant types for MCP"""
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
grant_type: MCPSupportGrantType = Field(default=MCPSupportGrantType.AUTHORIZATION_CODE)
class MCPConfiguration(BaseModel):
timeout: float = 30
sse_read_timeout: float = 300
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
@ -143,19 +162,25 @@ class MCPProviderEntity(BaseModel):
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"timeout": self.timeout,
"sse_read_timeout": self.sse_read_timeout,
"masked_headers": self.masked_headers(),
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add masked credentials if they exist
# Add configuration
response["configuration"] = {
"timeout": str(self.timeout),
"sse_read_timeout": str(self.sse_read_timeout),
}
# Add masked headers
response["masked_headers"] = self.masked_headers()
# Add authentication info if available
masked_creds = self.masked_credentials()
if masked_creds:
response.update(masked_creds)
response["authentication"] = masked_creds
response["is_dynamic_registration"] = self.credentials.get("is_dynamic_registration", True)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
@ -255,8 +280,6 @@ class MCPProviderEntity(BaseModel):
# Include non-sensitive fields (check both flat and nested structures)
if "grant_type" in credentials:
masked["grant_type"] = credentials["grant_type"]
if "scope" in credentials:
masked["scope"] = credentials["scope"]
return masked

View File

@ -44,15 +44,14 @@ class ToolProviderApiEntity(BaseModel):
server_url: str | None = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
# MCP OAuth credentials
client_id: str | None = Field(default=None, description="The masked client ID for OAuth")
client_secret: str | None = Field(default=None, description="The masked client secret for OAuth")
grant_type: str | None = Field(default=None, description="The OAuth grant type")
scope: str | None = Field(default=None, description="The OAuth scope")
authentication: dict[str, str] | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: dict[str, str] | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@field_validator("tools", mode="before")
@classmethod
@ -75,13 +74,11 @@ class ToolProviderApiEntity(BaseModel):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(self.optional_field("configuration", self.configuration))
optional_fields.update(self.optional_field("authentication", self.authentication))
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("client_id", self.client_id))
optional_fields.update(self.optional_field("client_secret", self.client_secret))
optional_fields.update(self.optional_field("grant_type", self.grant_type))
optional_fields.update(self.optional_field("scope", self.scope))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {
"id": self.id,
"author": self.author,

View File

@ -94,7 +94,6 @@ class MCPToolManageService:
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str = DEFAULT_GRANT_TYPE,
scope: str | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
@ -107,9 +106,7 @@ class MCPToolManageService:
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
if client_id and client_secret:
# Build the full credentials structure with encrypted client_id and client_secret
encrypted_credentials = self._build_and_encrypt_credentials(
client_id, client_secret, grant_type, scope, tenant_id
)
encrypted_credentials = self._build_and_encrypt_credentials(client_id, client_secret, grant_type, tenant_id)
else:
encrypted_credentials = None
# Create provider
@ -151,7 +148,6 @@ class MCPToolManageService:
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str | None = None,
scope: str | None = None,
) -> None:
"""Update an MCP provider."""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
@ -210,15 +206,14 @@ class MCPToolManageService:
final_client_id,
final_client_secret,
final_grant_type,
final_scope,
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider)
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, mcp_provider)
# Use default grant_type if none found
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
# Build and encrypt new credentials
encrypted_credentials = self._build_and_encrypt_credentials(
final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id
final_client_id, final_client_secret, final_grant_type, tenant_id
)
mcp_provider.encrypted_credentials = encrypted_credentials
@ -502,20 +497,18 @@ class MCPToolManageService:
client_id: str,
client_secret: str,
grant_type: str | None,
scope: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[str, str, str | None, str | None]:
) -> tuple[str, str, str | None]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
grant_type: Grant type from frontend
scope: OAuth scope from frontend
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret, grant_type, scope)
Tuple of (final_client_id, final_client_secret, grant_type)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
@ -533,14 +526,12 @@ class MCPToolManageService:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
# Grant type and scope are not masked, use as is
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
final_scope = scope if scope is not None else existing_decrypted.get("scope")
return final_client_id, final_client_secret, final_grant_type, final_scope
return final_client_id, final_client_secret, final_grant_type
def _build_and_encrypt_credentials(
self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str
self, client_id: str, client_secret: str, grant_type: str, tenant_id: str
) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
@ -549,11 +540,9 @@ class MCPToolManageService:
"client_secret": client_secret,
"grant_type": grant_type,
"client_name": CLIENT_NAME,
"is_dynamic_registration": False,
}
if scope:
credentials_data["scope"] = scope
# Add grant types and response types based on grant_type
if grant_type == "client_credentials":
credentials_data["grant_types"] = json.dumps(["client_credentials"])