mirror of https://github.com/langgenius/dify.git
chore: change the field
This commit is contained in:
parent
a538f80e95
commit
0a6da0bf2f
|
|
@ -17,6 +17,7 @@ from controllers.console.wraps import (
|
|||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPSupportGrantType
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
|
|
@ -895,39 +896,37 @@ class ToolProviderMCPApi(Resource):
|
|||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
||||
parser.add_argument(
|
||||
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
||||
)
|
||||
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
parser.add_argument("client_id", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument(
|
||||
"grant_type", type=str, required=False, nullable=True, location="json", default="authorization_code"
|
||||
)
|
||||
parser.add_argument("scope", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
|
||||
# Validate server URL
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Parse and validate models
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
user_id=user.id,
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args["timeout"],
|
||||
sse_read_timeout=args["sse_read_timeout"],
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
headers=args["headers"],
|
||||
client_id=args["client_id"],
|
||||
client_secret=args["client_secret"],
|
||||
grant_type=args["grant_type"],
|
||||
scope=args["scope"],
|
||||
client_id=authentication.client_id if authentication else None,
|
||||
client_secret=authentication.client_secret if authentication else None,
|
||||
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
|
||||
)
|
||||
session.commit()
|
||||
return jsonable_encoder(result)
|
||||
|
|
@ -944,19 +943,17 @@ class ToolProviderMCPApi(Resource):
|
|||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("client_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("grant_type", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("scope", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
|
|
@ -968,13 +965,12 @@ class ToolProviderMCPApi(Resource):
|
|||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args.get("timeout"),
|
||||
sse_read_timeout=args.get("sse_read_timeout"),
|
||||
headers=args.get("headers"),
|
||||
client_id=args.get("client_id"),
|
||||
client_secret=args.get("client_secret"),
|
||||
grant_type=args.get("grant_type"),
|
||||
scope=args.get("scope"),
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
headers=args["headers"],
|
||||
client_id=authentication.client_id if authentication else None,
|
||||
client_secret=authentication.client_secret if authentication else None,
|
||||
grant_type=authentication.grant_type if authentication else MCPSupportGrantType.AUTHORIZATION_CODE,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
|
|
@ -28,6 +29,24 @@ MASK_CHAR = "*"
|
|||
MIN_UNMASK_LENGTH = 6
|
||||
|
||||
|
||||
class MCPSupportGrantType(StrEnum):
|
||||
"""The supported grant types for MCP"""
|
||||
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
|
||||
|
||||
class MCPAuthentication(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
grant_type: MCPSupportGrantType = Field(default=MCPSupportGrantType.AUTHORIZATION_CODE)
|
||||
|
||||
|
||||
class MCPConfiguration(BaseModel):
|
||||
timeout: float = 30
|
||||
sse_read_timeout: float = 300
|
||||
|
||||
|
||||
class MCPProviderEntity(BaseModel):
|
||||
"""MCP Provider domain entity for business logic operations"""
|
||||
|
||||
|
|
@ -143,19 +162,25 @@ class MCPProviderEntity(BaseModel):
|
|||
"is_team_authorization": self.authed,
|
||||
"server_url": self.masked_server_url(),
|
||||
"server_identifier": self.provider_id,
|
||||
"timeout": self.timeout,
|
||||
"sse_read_timeout": self.sse_read_timeout,
|
||||
"masked_headers": self.masked_headers(),
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
}
|
||||
|
||||
# Add masked credentials if they exist
|
||||
# Add configuration
|
||||
response["configuration"] = {
|
||||
"timeout": str(self.timeout),
|
||||
"sse_read_timeout": str(self.sse_read_timeout),
|
||||
}
|
||||
|
||||
# Add masked headers
|
||||
response["masked_headers"] = self.masked_headers()
|
||||
|
||||
# Add authentication info if available
|
||||
masked_creds = self.masked_credentials()
|
||||
if masked_creds:
|
||||
response.update(masked_creds)
|
||||
|
||||
response["authentication"] = masked_creds
|
||||
response["is_dynamic_registration"] = self.credentials.get("is_dynamic_registration", True)
|
||||
return response
|
||||
|
||||
def retrieve_client_information(self) -> OAuthClientInformation | None:
|
||||
|
|
@ -255,8 +280,6 @@ class MCPProviderEntity(BaseModel):
|
|||
# Include non-sensitive fields (check both flat and nested structures)
|
||||
if "grant_type" in credentials:
|
||||
masked["grant_type"] = credentials["grant_type"]
|
||||
if "scope" in credentials:
|
||||
masked["scope"] = credentials["scope"]
|
||||
|
||||
return masked
|
||||
|
||||
|
|
|
|||
|
|
@ -44,15 +44,14 @@ class ToolProviderApiEntity(BaseModel):
|
|||
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
|
||||
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
|
||||
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
||||
|
||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||
# MCP OAuth credentials
|
||||
client_id: str | None = Field(default=None, description="The masked client ID for OAuth")
|
||||
client_secret: str | None = Field(default=None, description="The masked client secret for OAuth")
|
||||
grant_type: str | None = Field(default=None, description="The OAuth grant type")
|
||||
scope: str | None = Field(default=None, description="The OAuth scope")
|
||||
authentication: dict[str, str] | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
|
||||
configuration: dict[str, str] | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -75,13 +74,11 @@ class ToolProviderApiEntity(BaseModel):
|
|||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
||||
optional_fields.update(self.optional_field("configuration", self.configuration))
|
||||
optional_fields.update(self.optional_field("authentication", self.authentication))
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("client_id", self.client_id))
|
||||
optional_fields.update(self.optional_field("client_secret", self.client_secret))
|
||||
optional_fields.update(self.optional_field("grant_type", self.grant_type))
|
||||
optional_fields.update(self.optional_field("scope", self.scope))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
|
|
|||
|
|
@ -94,7 +94,6 @@ class MCPToolManageService:
|
|||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
grant_type: str = DEFAULT_GRANT_TYPE,
|
||||
scope: str | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
|
@ -107,9 +106,7 @@ class MCPToolManageService:
|
|||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
if client_id and client_secret:
|
||||
# Build the full credentials structure with encrypted client_id and client_secret
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
client_id, client_secret, grant_type, scope, tenant_id
|
||||
)
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(client_id, client_secret, grant_type, tenant_id)
|
||||
else:
|
||||
encrypted_credentials = None
|
||||
# Create provider
|
||||
|
|
@ -151,7 +148,6 @@ class MCPToolManageService:
|
|||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
grant_type: str | None = None,
|
||||
scope: str | None = None,
|
||||
) -> None:
|
||||
"""Update an MCP provider."""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
|
@ -210,15 +206,14 @@ class MCPToolManageService:
|
|||
final_client_id,
|
||||
final_client_secret,
|
||||
final_grant_type,
|
||||
final_scope,
|
||||
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider)
|
||||
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, mcp_provider)
|
||||
|
||||
# Use default grant_type if none found
|
||||
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
|
||||
|
||||
# Build and encrypt new credentials
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id
|
||||
final_client_id, final_client_secret, final_grant_type, tenant_id
|
||||
)
|
||||
mcp_provider.encrypted_credentials = encrypted_credentials
|
||||
|
||||
|
|
@ -502,20 +497,18 @@ class MCPToolManageService:
|
|||
client_id: str,
|
||||
client_secret: str,
|
||||
grant_type: str | None,
|
||||
scope: str | None,
|
||||
mcp_provider: MCPToolProvider,
|
||||
) -> tuple[str, str, str | None, str | None]:
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
client_id: Client ID from frontend (may be masked)
|
||||
client_secret: Client secret from frontend (may be masked)
|
||||
grant_type: Grant type from frontend
|
||||
scope: OAuth scope from frontend
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Tuple of (final_client_id, final_client_secret, grant_type, scope)
|
||||
Tuple of (final_client_id, final_client_secret, grant_type)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_credentials()
|
||||
|
|
@ -533,14 +526,12 @@ class MCPToolManageService:
|
|||
# Use existing decrypted value
|
||||
final_client_secret = existing_decrypted.get("client_secret", client_secret)
|
||||
|
||||
# Grant type and scope are not masked, use as is
|
||||
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
|
||||
final_scope = scope if scope is not None else existing_decrypted.get("scope")
|
||||
|
||||
return final_client_id, final_client_secret, final_grant_type, final_scope
|
||||
return final_client_id, final_client_secret, final_grant_type
|
||||
|
||||
def _build_and_encrypt_credentials(
|
||||
self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str
|
||||
self, client_id: str, client_secret: str, grant_type: str, tenant_id: str
|
||||
) -> str:
|
||||
"""Build credentials and encrypt sensitive fields."""
|
||||
# Create a flat structure with all credential data
|
||||
|
|
@ -549,11 +540,9 @@ class MCPToolManageService:
|
|||
"client_secret": client_secret,
|
||||
"grant_type": grant_type,
|
||||
"client_name": CLIENT_NAME,
|
||||
"is_dynamic_registration": False,
|
||||
}
|
||||
|
||||
if scope:
|
||||
credentials_data["scope"] = scope
|
||||
|
||||
# Add grant types and response types based on grant_type
|
||||
if grant_type == "client_credentials":
|
||||
credentials_data["grant_types"] = json.dumps(["client_credentials"])
|
||||
|
|
|
|||
Loading…
Reference in New Issue