mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 02:36:29 +08:00
refactor(mcp): clean the client service code
This commit is contained in:
parent
f16151ea29
commit
aed9955105
@ -7,6 +7,7 @@ from flask_restx import (
|
|||||||
Resource,
|
Resource,
|
||||||
reqparse,
|
reqparse,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -17,13 +18,13 @@ from controllers.console.wraps import (
|
|||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPError
|
||||||
from core.mcp.mcp_client import MCPClient
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.tools.entities.tool_entities import CredentialType
|
from core.tools.entities.tool_entities import CredentialType
|
||||||
|
from extensions.ext_database import db
|
||||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
@ -870,8 +871,9 @@ class ToolProviderMCPApi(Resource):
|
|||||||
user = current_user
|
user = current_user
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
return jsonable_encoder(
|
with Session(db.engine) as session:
|
||||||
MCPToolManageService.create_mcp_provider(
|
service = MCPToolManageService(session=session)
|
||||||
|
result = service.create_provider(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=user.current_tenant_id,
|
||||||
server_url=args["server_url"],
|
server_url=args["server_url"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -884,7 +886,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
sse_read_timeout=args["sse_read_timeout"],
|
sse_read_timeout=args["sse_read_timeout"],
|
||||||
headers=args["headers"],
|
headers=args["headers"],
|
||||||
)
|
)
|
||||||
)
|
session.commit()
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -907,20 +910,23 @@ class ToolProviderMCPApi(Resource):
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
MCPToolManageService.update_mcp_provider(
|
with Session(db.engine) as session:
|
||||||
tenant_id=current_user.current_tenant_id,
|
service = MCPToolManageService(session=session)
|
||||||
provider_id=args["provider_id"],
|
service.update_provider(
|
||||||
server_url=args["server_url"],
|
tenant_id=current_user.current_tenant_id,
|
||||||
name=args["name"],
|
provider_id=args["provider_id"],
|
||||||
icon=args["icon"],
|
server_url=args["server_url"],
|
||||||
icon_type=args["icon_type"],
|
name=args["name"],
|
||||||
icon_background=args["icon_background"],
|
icon=args["icon"],
|
||||||
server_identifier=args["server_identifier"],
|
icon_type=args["icon_type"],
|
||||||
timeout=args.get("timeout"),
|
icon_background=args["icon_background"],
|
||||||
sse_read_timeout=args.get("sse_read_timeout"),
|
server_identifier=args["server_identifier"],
|
||||||
headers=args.get("headers"),
|
timeout=args.get("timeout"),
|
||||||
)
|
sse_read_timeout=args.get("sse_read_timeout"),
|
||||||
return {"result": "success"}
|
headers=args.get("headers"),
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -929,8 +935,11 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
with Session(db.engine) as session:
|
||||||
return {"result": "success"}
|
service = MCPToolManageService(session=session)
|
||||||
|
service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPAuthApi(Resource):
|
class ToolMCPAuthApi(Resource):
|
||||||
@ -944,45 +953,50 @@ class ToolMCPAuthApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
provider_id = args["provider_id"]
|
provider_id = args["provider_id"]
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
if not provider:
|
|
||||||
raise ValueError("provider not found")
|
|
||||||
# headers1: if headers is provided, use it and don't need to get token
|
|
||||||
headers = provider.decrypted_headers or {}
|
|
||||||
|
|
||||||
# headers2: Add OAuth token if authed and no headers provided
|
with Session(db.engine) as session:
|
||||||
if not provider.decrypted_headers and provider.authed:
|
service = MCPToolManageService(session=session)
|
||||||
token = OAuthClientProvider(provider_id, tenant_id, for_list=True).tokens()
|
db_provider = service.get_provider_by_id(provider_id, tenant_id)
|
||||||
if token:
|
if not db_provider:
|
||||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
raise ValueError("provider not found")
|
||||||
try:
|
|
||||||
# try to connect to MCP server with headers
|
|
||||||
with MCPClient(
|
|
||||||
provider.decrypted_server_url,
|
|
||||||
headers=headers,
|
|
||||||
timeout=provider.timeout,
|
|
||||||
sse_read_timeout=provider.sse_read_timeout,
|
|
||||||
):
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
|
||||||
mcp_provider=provider,
|
|
||||||
credentials=provider.decrypted_credentials,
|
|
||||||
authed=True,
|
|
||||||
)
|
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
except MCPAuthError as e:
|
# Convert to entity
|
||||||
|
provider_entity = db_provider.to_entity()
|
||||||
|
server_url = provider_entity.decrypt_server_url()
|
||||||
|
|
||||||
|
# Option 1: if headers is provided, use it and don't need to get token
|
||||||
|
headers = provider_entity.decrypt_headers()
|
||||||
|
|
||||||
|
# Option 2: Add OAuth token if authed and no headers provided
|
||||||
|
if not provider_entity.headers and provider_entity.authed:
|
||||||
|
token = provider_entity.retrieve_tokens()
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||||
try:
|
try:
|
||||||
if provider.decrypted_headers:
|
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||||
raise ValueError(f"Failed to authenticate, please check your headers: {e}") from e
|
with MCPClientWithAuthRetry(
|
||||||
# if auth failed, try to auth with OAuth or exchange token
|
server_url=server_url,
|
||||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
headers=headers,
|
||||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
timeout=provider_entity.timeout,
|
||||||
except Exception as e:
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
|
provider_entity=provider_entity
|
||||||
raise ValueError(f"Failed to authenticate, please try again: {e}") from e
|
if not provider_entity.headers
|
||||||
except MCPError as e:
|
else None, # Only use auth retry if no custom headers
|
||||||
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
|
auth_callback=auth if not provider_entity.headers else None,
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
authorization_code=args.get("authorization_code"),
|
||||||
|
):
|
||||||
|
service.update_provider_credentials(
|
||||||
|
provider=db_provider,
|
||||||
|
credentials=provider_entity.credentials,
|
||||||
|
authed=True,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
except MCPError as e:
|
||||||
|
service.clear_provider_credentials(provider=db_provider)
|
||||||
|
session.commit()
|
||||||
|
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPDetailApi(Resource):
|
class ToolMCPDetailApi(Resource):
|
||||||
@ -991,8 +1005,10 @@ class ToolMCPDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
user = current_user
|
user = current_user
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
with Session(db.engine) as session:
|
||||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
service = MCPToolManageService(session=session)
|
||||||
|
provider = service.get_provider_by_id(provider_id, user.current_tenant_id)
|
||||||
|
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPListAllApi(Resource):
|
class ToolMCPListAllApi(Resource):
|
||||||
@ -1003,9 +1019,11 @@ class ToolMCPListAllApi(Resource):
|
|||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
with Session(db.engine) as session:
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
tools = service.list_providers(tenant_id=tenant_id)
|
||||||
|
|
||||||
return [tool.to_dict() for tool in tools]
|
return [tool.to_dict() for tool in tools]
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPUpdateApi(Resource):
|
class ToolMCPUpdateApi(Resource):
|
||||||
@ -1014,11 +1032,13 @@ class ToolMCPUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
with Session(db.engine) as session:
|
||||||
tenant_id=tenant_id,
|
service = MCPToolManageService(session=session)
|
||||||
provider_id=provider_id,
|
tools = service.list_provider_tools(
|
||||||
)
|
tenant_id=tenant_id,
|
||||||
return jsonable_encoder(tools)
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
return jsonable_encoder(tools)
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPCallbackApi(Resource):
|
class ToolMCPCallbackApi(Resource):
|
||||||
|
|||||||
202
api/core/entities/mcp_provider.py
Normal file
202
api/core/entities/mcp_provider.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.file import helpers as file_helpers
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from core.tools.utils.encryption import create_provider_encrypter
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from models.tools import MCPToolProvider
|
||||||
|
|
||||||
|
|
||||||
|
class MCPProviderEntity(BaseModel):
|
||||||
|
"""MCP Provider domain entity for business logic operations"""
|
||||||
|
|
||||||
|
# Basic identification
|
||||||
|
id: str
|
||||||
|
provider_id: str # server_identifier
|
||||||
|
name: str
|
||||||
|
tenant_id: str
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
# Server connection info
|
||||||
|
server_url: str # encrypted URL
|
||||||
|
headers: dict[str, str] # encrypted headers
|
||||||
|
timeout: float
|
||||||
|
sse_read_timeout: float
|
||||||
|
|
||||||
|
# Authentication related
|
||||||
|
authed: bool
|
||||||
|
credentials: dict[str, Any] # encrypted credentials
|
||||||
|
code_verifier: Optional[str] = None # for OAuth
|
||||||
|
|
||||||
|
# Tools and display info
|
||||||
|
tools: list[dict[str, Any]] # parsed tools list
|
||||||
|
icon: str | dict[str, str] # parsed icon
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
||||||
|
"""Create entity from database model with decryption"""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=db_provider.id,
|
||||||
|
provider_id=db_provider.server_identifier,
|
||||||
|
name=db_provider.name,
|
||||||
|
tenant_id=db_provider.tenant_id,
|
||||||
|
user_id=db_provider.user_id,
|
||||||
|
server_url=db_provider.server_url,
|
||||||
|
headers=db_provider.headers,
|
||||||
|
timeout=db_provider.timeout,
|
||||||
|
sse_read_timeout=db_provider.sse_read_timeout,
|
||||||
|
authed=db_provider.authed,
|
||||||
|
credentials=db_provider.credentials,
|
||||||
|
tools=db_provider.tool_dict,
|
||||||
|
icon=db_provider.icon or "",
|
||||||
|
created_at=db_provider.created_at,
|
||||||
|
updated_at=db_provider.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def redirect_url(self) -> str:
|
||||||
|
"""OAuth redirect URL"""
|
||||||
|
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_metadata(self) -> OAuthClientMetadata:
|
||||||
|
"""Metadata about this OAuth client."""
|
||||||
|
return OAuthClientMetadata(
|
||||||
|
redirect_uris=[self.redirect_url],
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
client_name="Dify",
|
||||||
|
client_uri="https://github.com/langgenius/dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_icon(self) -> dict[str, str] | str:
|
||||||
|
"""Get provider icon, handling both dict and string formats"""
|
||||||
|
if isinstance(self.icon, dict):
|
||||||
|
return self.icon
|
||||||
|
try:
|
||||||
|
return json.loads(self.icon)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
# If not JSON, assume it's a file path
|
||||||
|
return file_helpers.get_signed_file_url(self.icon)
|
||||||
|
|
||||||
|
def to_api_response(self, user_name: Optional[str] = None) -> dict[str, Any]:
|
||||||
|
"""Convert to API response format"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"author": user_name or "Anonymous",
|
||||||
|
"name": self.name,
|
||||||
|
"icon": self.provider_icon,
|
||||||
|
"type": ToolProviderType.MCP.value,
|
||||||
|
"is_team_authorization": self.authed,
|
||||||
|
"server_url": self.masked_server_url(),
|
||||||
|
"server_identifier": self.provider_id,
|
||||||
|
"timeout": self.timeout,
|
||||||
|
"sse_read_timeout": self.sse_read_timeout,
|
||||||
|
"masked_headers": self.masked_headers(),
|
||||||
|
"updated_at": int(self.updated_at.timestamp()),
|
||||||
|
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||||
|
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def retrieve_client_information(self) -> Optional[OAuthClientInformation]:
|
||||||
|
"""OAuth client information if available"""
|
||||||
|
client_info = self.decrypt_credentials().get("client_information", {})
|
||||||
|
if not client_info:
|
||||||
|
return None
|
||||||
|
return OAuthClientInformation.model_validate(client_info)
|
||||||
|
|
||||||
|
def retrieve_tokens(self) -> Optional[OAuthTokens]:
|
||||||
|
"""OAuth tokens if available"""
|
||||||
|
if not self.credentials:
|
||||||
|
return None
|
||||||
|
credentials = self.decrypt_credentials()
|
||||||
|
return OAuthTokens(
|
||||||
|
access_token=credentials.get("access_token", ""),
|
||||||
|
token_type=credentials.get("token_type", "Bearer"),
|
||||||
|
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||||
|
refresh_token=credentials.get("refresh_token", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def masked_server_url(self) -> str:
|
||||||
|
"""Masked server URL for display"""
|
||||||
|
parsed = urlparse(self.decrypt_server_url())
|
||||||
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
if parsed.path and parsed.path != "/":
|
||||||
|
return f"{base_url}/******"
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
def masked_headers(self) -> dict[str, str]:
|
||||||
|
"""Masked headers for display"""
|
||||||
|
masked: dict[str, str] = {}
|
||||||
|
for key, value in self.decrypt_headers().items():
|
||||||
|
if len(value) > 6:
|
||||||
|
masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
|
||||||
|
else:
|
||||||
|
masked[key] = "*" * len(value)
|
||||||
|
return masked
|
||||||
|
|
||||||
|
def decrypt_server_url(self) -> str:
|
||||||
|
"""Decrypt server URL"""
|
||||||
|
|
||||||
|
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||||
|
|
||||||
|
def decrypt_headers(self) -> dict[str, Any]:
|
||||||
|
"""Decrypt headers"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.headers:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Create dynamic config for all headers as SECRET_INPUT
|
||||||
|
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers]
|
||||||
|
|
||||||
|
encrypter_instance, _ = create_provider_encrypter(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
config=config,
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = encrypter_instance.decrypt(self.headers)
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def decrypt_credentials(
|
||||||
|
self,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Decrypt credentials"""
|
||||||
|
try:
|
||||||
|
if not self.credentials:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
config=[
|
||||||
|
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key)
|
||||||
|
for key in self.credentials
|
||||||
|
],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return encrypter.decrypt(self.credentials)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
@ -9,8 +9,9 @@ from urllib.parse import urljoin, urlparse
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
from core.mcp.types import (
|
from core.mcp.types import (
|
||||||
LATEST_PROTOCOL_VERSION,
|
LATEST_PROTOCOL_VERSION,
|
||||||
OAuthClientInformation,
|
OAuthClientInformation,
|
||||||
@ -19,7 +20,9 @@ from core.mcp.types import (
|
|||||||
OAuthMetadata,
|
OAuthMetadata,
|
||||||
OAuthTokens,
|
OAuthTokens,
|
||||||
)
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
from services.tools.mcp_oauth_service import MCPOAuthService
|
||||||
|
|
||||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||||
@ -94,8 +97,13 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
|||||||
full_state_data.code_verifier,
|
full_state_data.code_verifier,
|
||||||
full_state_data.redirect_uri,
|
full_state_data.redirect_uri,
|
||||||
)
|
)
|
||||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
|
||||||
provider.save_tokens(tokens)
|
# Save tokens using the service layer
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
oauth_service = MCPOAuthService(session=session)
|
||||||
|
oauth_service.save_tokens(full_state_data.provider_id, full_state_data.tenant_id, tokens)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return full_state_data
|
return full_state_data
|
||||||
|
|
||||||
|
|
||||||
@ -295,24 +303,33 @@ def register_client(
|
|||||||
|
|
||||||
|
|
||||||
def auth(
|
def auth(
|
||||||
provider: OAuthClientProvider,
|
provider: MCPProviderEntity,
|
||||||
server_url: str,
|
|
||||||
authorization_code: Optional[str] = None,
|
authorization_code: Optional[str] = None,
|
||||||
state_param: Optional[str] = None,
|
state_param: Optional[str] = None,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||||
metadata = discover_oauth_metadata(server_url)
|
server_url = provider.decrypt_server_url()
|
||||||
|
server_metadata = discover_oauth_metadata(server_url)
|
||||||
|
client_metadata = provider.client_metadata
|
||||||
|
provider_id = provider.id
|
||||||
|
tenant_id = provider.tenant_id
|
||||||
|
client_information = provider.retrieve_client_information()
|
||||||
|
redirect_url = provider.redirect_url
|
||||||
|
|
||||||
# Handle client registration if needed
|
|
||||||
client_information = provider.client_information()
|
|
||||||
if not client_information:
|
if not client_information:
|
||||||
if authorization_code is not None:
|
if authorization_code is not None:
|
||||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||||
try:
|
try:
|
||||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
raise ValueError(f"Could not register OAuth client: {e}")
|
raise ValueError(f"Could not register OAuth client: {e}")
|
||||||
provider.save_client_information(full_information)
|
|
||||||
|
# Save client information using service layer
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
oauth_service = MCPOAuthService(session=session)
|
||||||
|
oauth_service.save_client_information(provider_id, tenant_id, full_information)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
client_information = full_information
|
client_information = full_information
|
||||||
|
|
||||||
# Exchange authorization code for tokens
|
# Exchange authorization code for tokens
|
||||||
@ -335,22 +352,36 @@ def auth(
|
|||||||
|
|
||||||
tokens = exchange_authorization(
|
tokens = exchange_authorization(
|
||||||
server_url,
|
server_url,
|
||||||
metadata,
|
server_metadata,
|
||||||
client_information,
|
client_information,
|
||||||
authorization_code,
|
authorization_code,
|
||||||
code_verifier,
|
code_verifier,
|
||||||
redirect_uri,
|
redirect_uri,
|
||||||
)
|
)
|
||||||
provider.save_tokens(tokens)
|
|
||||||
|
# Save tokens using service layer
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
oauth_service = MCPOAuthService(session=session)
|
||||||
|
oauth_service.save_tokens(provider_id, tenant_id, tokens)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
provider_tokens = provider.tokens()
|
provider_tokens = provider.retrieve_tokens()
|
||||||
|
|
||||||
# Handle token refresh or new authorization
|
# Handle token refresh or new authorization
|
||||||
if provider_tokens and provider_tokens.refresh_token:
|
if provider_tokens and provider_tokens.refresh_token:
|
||||||
try:
|
try:
|
||||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
new_tokens = refresh_authorization(
|
||||||
provider.save_tokens(new_tokens)
|
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save new tokens using service layer
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
oauth_service = MCPOAuthService(session=session)
|
||||||
|
oauth_service.save_tokens(provider_id, tenant_id, new_tokens)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||||
@ -358,12 +389,17 @@ def auth(
|
|||||||
# Start new authorization flow
|
# Start new authorization flow
|
||||||
authorization_url, code_verifier = start_authorization(
|
authorization_url, code_verifier = start_authorization(
|
||||||
server_url,
|
server_url,
|
||||||
metadata,
|
server_metadata,
|
||||||
client_information,
|
client_information,
|
||||||
provider.redirect_url,
|
redirect_url,
|
||||||
provider.mcp_provider.id,
|
provider_id,
|
||||||
provider.mcp_provider.tenant_id,
|
tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider.save_code_verifier(code_verifier)
|
# Save code verifier using service layer
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
oauth_service = MCPOAuthService(session=session)
|
||||||
|
oauth_service.save_code_verifier(provider_id, tenant_id, code_verifier)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return {"authorization_url": authorization_url}
|
return {"authorization_url": authorization_url}
|
||||||
|
|||||||
@ -1,79 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.mcp.types import (
|
|
||||||
OAuthClientInformation,
|
|
||||||
OAuthClientInformationFull,
|
|
||||||
OAuthClientMetadata,
|
|
||||||
OAuthTokens,
|
|
||||||
)
|
|
||||||
from models.tools import MCPToolProvider
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthClientProvider:
|
|
||||||
mcp_provider: MCPToolProvider
|
|
||||||
|
|
||||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
|
||||||
if for_list:
|
|
||||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
else:
|
|
||||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def redirect_url(self) -> str:
|
|
||||||
"""The URL to redirect the user agent to after authorization."""
|
|
||||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client_metadata(self) -> OAuthClientMetadata:
|
|
||||||
"""Metadata about this OAuth client."""
|
|
||||||
return OAuthClientMetadata(
|
|
||||||
redirect_uris=[self.redirect_url],
|
|
||||||
token_endpoint_auth_method="none",
|
|
||||||
grant_types=["authorization_code", "refresh_token"],
|
|
||||||
response_types=["code"],
|
|
||||||
client_name="Dify",
|
|
||||||
client_uri="https://github.com/langgenius/dify",
|
|
||||||
)
|
|
||||||
|
|
||||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
|
||||||
"""Loads information about this OAuth client."""
|
|
||||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
|
||||||
if not client_information:
|
|
||||||
return None
|
|
||||||
return OAuthClientInformation.model_validate(client_information)
|
|
||||||
|
|
||||||
def save_client_information(self, client_information: OAuthClientInformationFull):
|
|
||||||
"""Saves client information after dynamic registration."""
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
|
||||||
self.mcp_provider,
|
|
||||||
{"client_information": client_information.model_dump()},
|
|
||||||
)
|
|
||||||
|
|
||||||
def tokens(self) -> Optional[OAuthTokens]:
|
|
||||||
"""Loads any existing OAuth tokens for the current session."""
|
|
||||||
credentials = self.mcp_provider.decrypted_credentials
|
|
||||||
if not credentials:
|
|
||||||
return None
|
|
||||||
return OAuthTokens(
|
|
||||||
access_token=credentials.get("access_token", ""),
|
|
||||||
token_type=credentials.get("token_type", "Bearer"),
|
|
||||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
|
||||||
refresh_token=credentials.get("refresh_token", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
def save_tokens(self, tokens: OAuthTokens):
|
|
||||||
"""Stores new OAuth tokens for the current session."""
|
|
||||||
# update mcp provider credentials
|
|
||||||
token_dict = tokens.model_dump()
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
|
||||||
|
|
||||||
def save_code_verifier(self, code_verifier: str):
|
|
||||||
"""Saves a PKCE code verifier for the current session."""
|
|
||||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
|
||||||
|
|
||||||
def code_verifier(self) -> str:
|
|
||||||
"""Loads the PKCE code verifier for the current session."""
|
|
||||||
# get code verifier from mcp provider credentials
|
|
||||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
|
||||||
204
api/core/mcp/auth_client.py
Normal file
204
api/core/mcp/auth_client.py
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
MCP Client with Authentication Retry Support
|
||||||
|
|
||||||
|
This module provides a wrapper around MCPClient that automatically handles
|
||||||
|
authentication failures and retries operations after refreshing tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
from core.mcp.error import MCPAuthError
|
||||||
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
from core.mcp.types import CallToolResult, Tool
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClientWithAuthRetry:
|
||||||
|
"""
|
||||||
|
A wrapper around MCPClient that provides automatic authentication retry.
|
||||||
|
|
||||||
|
This class intercepts MCPAuthError exceptions and attempts to refresh
|
||||||
|
authentication before retrying the failed operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
sse_read_timeout: float | None = None,
|
||||||
|
provider_entity: MCPProviderEntity | None = None,
|
||||||
|
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]] | None = None,
|
||||||
|
authorization_code: Optional[str] = None,
|
||||||
|
by_server_id: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the MCP client with auth retry capability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_url: The MCP server URL
|
||||||
|
headers: Optional headers for requests
|
||||||
|
timeout: Request timeout
|
||||||
|
sse_read_timeout: SSE read timeout
|
||||||
|
provider_entity: Provider entity for authentication
|
||||||
|
auth_callback: Authentication callback function
|
||||||
|
authorization_code: Optional authorization code for initial auth
|
||||||
|
"""
|
||||||
|
self.server_url = server_url
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.sse_read_timeout = sse_read_timeout
|
||||||
|
self.provider_entity = provider_entity
|
||||||
|
self.auth_callback = auth_callback
|
||||||
|
self.authorization_code = authorization_code
|
||||||
|
self._has_retried = False
|
||||||
|
self._client: MCPClient | None = None
|
||||||
|
self.by_server_id = by_server_id
|
||||||
|
|
||||||
|
def _create_client(self) -> MCPClient:
|
||||||
|
"""Create a new MCPClient instance with current headers."""
|
||||||
|
return MCPClient(
|
||||||
|
server_url=self.server_url,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||||
|
"""
|
||||||
|
Handle authentication error by refreshing tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The authentication error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails or max retries reached
|
||||||
|
"""
|
||||||
|
from services.tools.mcp_oauth_service import MCPOAuthService
|
||||||
|
|
||||||
|
if not self.provider_entity or not self.auth_callback:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
if self._has_retried:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
self._has_retried = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Perform authentication
|
||||||
|
self.auth_callback(self.provider_entity, self.authorization_code)
|
||||||
|
|
||||||
|
# Retrieve new tokens
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
oauth_service = MCPOAuthService(session=session)
|
||||||
|
self.provider_entity = oauth_service.get_provider_entity(
|
||||||
|
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||||
|
)
|
||||||
|
token = self.provider_entity.retrieve_tokens()
|
||||||
|
if not token:
|
||||||
|
raise MCPAuthError("Authentication failed - no token received")
|
||||||
|
|
||||||
|
# Update headers with new token
|
||||||
|
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||||
|
|
||||||
|
# Clear authorization code after first use
|
||||||
|
self.authorization_code = None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Authentication retry failed")
|
||||||
|
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||||
|
|
||||||
|
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a function with authentication retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the function call
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails after retries
|
||||||
|
Any other exceptions from the function
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except MCPAuthError as e:
|
||||||
|
self._handle_auth_error(e)
|
||||||
|
# Recreate client with new headers
|
||||||
|
if self._client:
|
||||||
|
self._client.cleanup()
|
||||||
|
self._client = self._create_client()
|
||||||
|
self._client.__enter__()
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# Reset retry flag after operation completes
|
||||||
|
self._has_retried = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Enter the context manager."""
|
||||||
|
self._client = self._create_client()
|
||||||
|
|
||||||
|
# Try to initialize with retry
|
||||||
|
def initialize():
|
||||||
|
if self._client is None:
|
||||||
|
raise ValueError("Client not created")
|
||||||
|
self._client.__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
return self._execute_with_retry(initialize)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
|
||||||
|
"""Exit the context manager."""
|
||||||
|
if self._client:
|
||||||
|
self._client.__exit__(exc_type, exc_value, traceback)
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
def list_tools(self) -> list[Tool]:
|
||||||
|
"""
|
||||||
|
List available tools from the MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of available tools
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails after retries
|
||||||
|
"""
|
||||||
|
if not self._client:
|
||||||
|
raise ValueError("Client not initialized. Use within a context manager.")
|
||||||
|
return self._execute_with_retry(self._client.list_tools)
|
||||||
|
|
||||||
|
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||||
|
"""
|
||||||
|
Invoke a tool on the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to invoke
|
||||||
|
tool_args: Arguments for the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of the tool invocation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPAuthError: If authentication fails after retries
|
||||||
|
"""
|
||||||
|
if not self._client:
|
||||||
|
raise ValueError("Client not initialized. Use within a context manager.")
|
||||||
|
return self._execute_with_retry(self._client.invoke_tool, tool_name, tool_args)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up resources."""
|
||||||
|
if self._client:
|
||||||
|
self._client.cleanup()
|
||||||
|
self._client = None
|
||||||
@ -46,7 +46,6 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool")
|
timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool")
|
||||||
sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
||||||
masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool")
|
masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool")
|
||||||
original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool")
|
|
||||||
|
|
||||||
@field_validator("tools", mode="before")
|
@field_validator("tools", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -72,7 +71,6 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
optional_fields.update(self.optional_field("timeout", self.timeout))
|
||||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
||||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"author": self.author,
|
"author": self.author,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
|
||||||
from typing import Any, Optional, Self
|
from typing import Any, Optional, Self
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
from core.mcp.types import Tool as RemoteMCPTool
|
from core.mcp.types import Tool as RemoteMCPTool
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@ -52,18 +52,28 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
"""
|
"""
|
||||||
from db provider
|
from db provider
|
||||||
"""
|
"""
|
||||||
tools = []
|
# Convert to entity first
|
||||||
tools_data = json.loads(db_provider.tools)
|
provider_entity = db_provider.to_entity()
|
||||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
|
return cls.from_entity(provider_entity)
|
||||||
user = db_provider.load_user()
|
|
||||||
|
@classmethod
|
||||||
|
def from_entity(cls, entity: MCPProviderEntity) -> Self:
|
||||||
|
"""
|
||||||
|
create a MCPToolProviderController from a MCPProviderEntity
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||||
|
except Exception:
|
||||||
|
remote_mcp_tools = []
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
ToolEntity(
|
ToolEntity(
|
||||||
identity=ToolIdentity(
|
identity=ToolIdentity(
|
||||||
author=user.name if user else "Anonymous",
|
author="Anonymous", # Tool level author is not stored
|
||||||
name=remote_mcp_tool.name,
|
name=remote_mcp_tool.name,
|
||||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||||
provider=db_provider.server_identifier,
|
provider=entity.provider_id,
|
||||||
icon=db_provider.icon,
|
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||||
),
|
),
|
||||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||||
description=ToolDescription(
|
description=ToolDescription(
|
||||||
@ -81,22 +91,22 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
return cls(
|
return cls(
|
||||||
entity=ToolProviderEntityWithPlugin(
|
entity=ToolProviderEntityWithPlugin(
|
||||||
identity=ToolProviderIdentity(
|
identity=ToolProviderIdentity(
|
||||||
author=user.name if user else "Anonymous",
|
author="Anonymous", # Provider level author is not stored in entity
|
||||||
name=db_provider.name,
|
name=entity.name,
|
||||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
description=I18nObject(en_US="", zh_Hans=""),
|
||||||
icon=db_provider.icon,
|
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||||
),
|
),
|
||||||
plugin_id=None,
|
plugin_id=None,
|
||||||
credentials_schema=[],
|
credentials_schema=[],
|
||||||
tools=tools,
|
tools=tools,
|
||||||
),
|
),
|
||||||
provider_id=db_provider.server_identifier or "",
|
provider_id=entity.provider_id,
|
||||||
tenant_id=db_provider.tenant_id or "",
|
tenant_id=entity.tenant_id,
|
||||||
server_url=db_provider.decrypted_server_url,
|
server_url=entity.server_url,
|
||||||
headers=db_provider.decrypted_headers or {},
|
headers=entity.headers,
|
||||||
timeout=db_provider.timeout,
|
timeout=entity.timeout,
|
||||||
sse_read_timeout=db_provider.sse_read_timeout,
|
sse_read_timeout=entity.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from collections.abc import Generator
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.mcp.auth.auth_flow import auth
|
from core.mcp.auth.auth_flow import auth
|
||||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.error import MCPConnectionError
|
||||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@ -118,59 +118,37 @@ class MCPTool(Tool):
|
|||||||
headers = self.headers.copy() if self.headers else {}
|
headers = self.headers.copy() if self.headers else {}
|
||||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||||
|
|
||||||
# Initialize auth provider
|
# Get provider entity to access tokens
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
provider = None
|
from extensions.ext_database import db
|
||||||
|
from services.tools.mcp_oauth_service import MCPOAuthService
|
||||||
|
|
||||||
try:
|
try:
|
||||||
provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=False)
|
with Session(db.engine) as session:
|
||||||
except Exception as e:
|
oauth_service = MCPOAuthService(session=session)
|
||||||
# If provider initialization fails, continue without auth
|
provider_entity = oauth_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||||
|
|
||||||
|
# Try to get existing token and add to headers
|
||||||
|
tokens = provider_entity.retrieve_tokens()
|
||||||
|
if tokens and tokens.access_token:
|
||||||
|
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||||
|
except Exception:
|
||||||
|
# If provider retrieval or token fails, continue without auth
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Try to get existing token and add to headers
|
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||||
if provider:
|
try:
|
||||||
try:
|
with MCPClientWithAuthRetry(
|
||||||
token = provider.tokens()
|
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||||
if token:
|
headers=headers,
|
||||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
|
||||||
except Exception:
|
|
||||||
# If token retrieval fails, continue without auth header
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Define a helper function to invoke the tool
|
|
||||||
def _invoke_with_client(client_headers: dict[str, str]) -> CallToolResult:
|
|
||||||
with MCPClient(
|
|
||||||
self.server_url,
|
|
||||||
headers=client_headers,
|
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
|
provider_entity=provider_entity,
|
||||||
|
auth_callback=auth,
|
||||||
|
by_server_id=True,
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||||
|
|
||||||
try:
|
|
||||||
# First attempt with current headers
|
|
||||||
return _invoke_with_client(headers)
|
|
||||||
except MCPAuthError as e:
|
|
||||||
# Authentication required - try to authenticate
|
|
||||||
if not provider:
|
|
||||||
raise ToolInvokeError("Authentication required but no auth provider available") from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Perform authentication flow
|
|
||||||
auth(provider, self.server_url, None, None, False)
|
|
||||||
token = provider.tokens()
|
|
||||||
if not token:
|
|
||||||
raise ToolInvokeError("Authentication failed - no token received")
|
|
||||||
|
|
||||||
# Update headers with new token while preserving other headers
|
|
||||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
|
||||||
|
|
||||||
# Retry with authenticated headers
|
|
||||||
return _invoke_with_client(headers)
|
|
||||||
except MCPAuthError as auth_error:
|
|
||||||
raise ToolInvokeError("Authentication failed") from auth_error
|
|
||||||
except MCPConnectionError as e:
|
except MCPConnectionError as e:
|
||||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool
|
|||||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from extensions.ext_database import db
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
@ -59,8 +60,7 @@ from core.tools.tool_label_manager import ToolLabelManager
|
|||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -715,7 +715,9 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||||
if "mcp" in filters:
|
if "mcp" in filters:
|
||||||
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
|
with Session(db.engine) as session:
|
||||||
|
mcp_service = MCPToolManageService(session=session)
|
||||||
|
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||||
for mcp_provider in mcp_providers:
|
for mcp_provider in mcp_providers:
|
||||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||||
|
|
||||||
@ -770,17 +772,12 @@ class ToolManager:
|
|||||||
|
|
||||||
:return: the provider controller, the credentials
|
:return: the provider controller, the credentials
|
||||||
"""
|
"""
|
||||||
provider: MCPToolProvider | None = (
|
with Session(db.engine) as session:
|
||||||
db.session.query(MCPToolProvider)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
.where(
|
try:
|
||||||
MCPToolProvider.server_identifier == provider_id,
|
provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||||
MCPToolProvider.tenant_id == tenant_id,
|
except ValueError:
|
||||||
)
|
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider is None:
|
|
||||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
|
||||||
|
|
||||||
controller = MCPToolProviderController.from_db(provider)
|
controller = MCPToolProviderController.from_db(provider)
|
||||||
|
|
||||||
@ -918,16 +915,13 @@ class ToolManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
||||||
try:
|
try:
|
||||||
mcp_provider: MCPToolProvider | None = (
|
with Session(db.engine) as session:
|
||||||
db.session.query(MCPToolProvider)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
try:
|
||||||
.first()
|
mcp_provider = mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||||
)
|
return mcp_provider.provider_icon
|
||||||
|
except ValueError:
|
||||||
if mcp_provider is None:
|
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
|
||||||
|
|
||||||
return mcp_provider.provider_icon
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
|
||||||
|
|||||||
@ -1,16 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from sqlalchemy import ForeignKey, String, func
|
from sqlalchemy import ForeignKey, String, func
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from core.file import helpers as file_helpers
|
|
||||||
from core.helper import encrypter
|
|
||||||
from core.mcp.types import Tool
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||||
@ -20,6 +16,9 @@ from .engine import db
|
|||||||
from .model import Account, App, Tenant
|
from .model import Account, App, Tenant
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
|
||||||
|
|
||||||
# system level tool oauth client params (client_id, client_secret, etc.)
|
# system level tool oauth client params (client_id, client_secret, etc.)
|
||||||
class ToolOAuthSystemClient(TypeBase):
|
class ToolOAuthSystemClient(TypeBase):
|
||||||
@ -286,119 +285,34 @@ class MCPToolProvider(Base):
|
|||||||
def load_user(self) -> Account | None:
|
def load_user(self) -> Account | None:
|
||||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||||
|
|
||||||
@property
|
|
||||||
def tenant(self) -> Tenant | None:
|
|
||||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials(self) -> dict[str, Any]:
|
def credentials(self) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
|
return json.loads(self.encrypted_credentials)
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mcp_tools(self) -> list[Tool]:
|
def headers(self) -> dict[str, Any]:
|
||||||
return [Tool(**tool) for tool in json.loads(self.tools)]
|
if self.encrypted_headers is None:
|
||||||
|
return {}
|
||||||
@property
|
|
||||||
def provider_icon(self) -> dict[str, str] | str:
|
|
||||||
try:
|
try:
|
||||||
return cast(dict[str, str], json.loads(self.icon))
|
return json.loads(self.encrypted_headers)
|
||||||
except json.JSONDecodeError:
|
|
||||||
return file_helpers.get_signed_file_url(self.icon)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def decrypted_server_url(self) -> str:
|
|
||||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def decrypted_headers(self) -> dict[str, Any]:
|
|
||||||
"""Get decrypted headers for MCP server requests."""
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not self.encrypted_headers:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
headers_data = json.loads(self.encrypted_headers)
|
|
||||||
|
|
||||||
# Create dynamic config for all headers as SECRET_INPUT
|
|
||||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
|
||||||
|
|
||||||
encrypter_instance, _ = create_provider_encrypter(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
config=config,
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = encrypter_instance.decrypt(headers_data)
|
|
||||||
return result
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def masked_headers(self) -> dict[str, Any]:
|
def tool_dict(self) -> list[dict[str, Any]]:
|
||||||
"""Get masked headers for frontend display."""
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.encrypted_headers:
|
return json.loads(self.tools) if self.tools else []
|
||||||
return {}
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return []
|
||||||
|
|
||||||
headers_data = json.loads(self.encrypted_headers)
|
def to_entity(self) -> "MCPProviderEntity":
|
||||||
|
"""Convert to domain entity"""
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
|
||||||
# Create dynamic config for all headers as SECRET_INPUT
|
return MCPProviderEntity.from_db_model(self)
|
||||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
|
||||||
|
|
||||||
encrypter_instance, _ = create_provider_encrypter(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
config=config,
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# First decrypt, then mask
|
|
||||||
decrypted_headers = encrypter_instance.decrypt(headers_data)
|
|
||||||
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
|
|
||||||
return result
|
|
||||||
except Exception:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def masked_server_url(self) -> str:
|
|
||||||
def mask_url(url: str, mask_char: str = "*") -> str:
|
|
||||||
"""
|
|
||||||
mask the url to a simple string
|
|
||||||
"""
|
|
||||||
parsed = urlparse(url)
|
|
||||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
|
|
||||||
if parsed.path and parsed.path != "/":
|
|
||||||
return f"{base_url}/{mask_char * 6}"
|
|
||||||
else:
|
|
||||||
return base_url
|
|
||||||
|
|
||||||
return mask_url(self.decrypted_server_url)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def decrypted_credentials(self) -> dict[str, Any]:
|
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
|
||||||
from core.tools.utils.encryption import create_provider_encrypter
|
|
||||||
|
|
||||||
provider_controller = MCPToolProviderController.from_db(self)
|
|
||||||
|
|
||||||
encrypter, _ = create_provider_encrypter(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
cache=NoOpProviderCredentialCache(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return encrypter.decrypt(self.credentials)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolModelInvoke(Base):
|
class ToolModelInvoke(Base):
|
||||||
|
|||||||
53
api/services/tools/mcp_oauth_service.py
Normal file
53
api/services/tools/mcp_oauth_service.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
MCP OAuth Service - handles OAuth-related database operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
from core.mcp.types import OAuthClientInformationFull, OAuthTokens
|
||||||
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthService:
|
||||||
|
"""Service for handling MCP OAuth operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
self._session = session
|
||||||
|
self._mcp_service = MCPToolManageService(session=session)
|
||||||
|
|
||||||
|
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||||
|
"""Get provider entity by ID"""
|
||||||
|
if by_server_id:
|
||||||
|
db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||||
|
else:
|
||||||
|
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
return db_provider.to_entity()
|
||||||
|
|
||||||
|
def save_client_information(
|
||||||
|
self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull
|
||||||
|
) -> None:
|
||||||
|
"""Save OAuth client information"""
|
||||||
|
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
self._mcp_service.update_provider_credentials(
|
||||||
|
provider=db_provider,
|
||||||
|
credentials={"client_information": client_information.model_dump()},
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None:
|
||||||
|
"""Save OAuth tokens"""
|
||||||
|
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
token_dict = tokens.model_dump()
|
||||||
|
self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed)
|
||||||
|
|
||||||
|
def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None:
|
||||||
|
"""Save PKCE code verifier"""
|
||||||
|
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
self._mcp_service.update_provider_credentials(
|
||||||
|
provider=db_provider, credentials={"code_verifier": code_verifier}
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear_credentials(self, provider_id: str, tenant_id: str) -> None:
|
||||||
|
"""Clear provider credentials"""
|
||||||
|
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
self._mcp_service.clear_provider_credentials(provider=db_provider)
|
||||||
@ -1,24 +1,27 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
from core.mcp.mcp_client import MCPClient
|
|
||||||
from core.mcp.types import OAuthTokens
|
from core.mcp.types import OAuthTokens
|
||||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
|
||||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.tools import MCPToolProvider
|
from models.tools import MCPToolProvider
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||||
|
|
||||||
|
|
||||||
@ -27,8 +30,10 @@ class MCPToolManageService:
|
|||||||
Service class for managing mcp tools.
|
Service class for managing mcp tools.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self, session: Session):
|
||||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
self._session = session
|
||||||
|
|
||||||
|
def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||||
|
|
||||||
@ -57,48 +62,53 @@ class MCPToolManageService:
|
|||||||
|
|
||||||
return encrypter_instance.encrypt(headers)
|
return encrypter_instance.encrypt(headers)
|
||||||
|
|
||||||
@staticmethod
|
def _retrieve_remote_mcp_tools(
|
||||||
def _retrieve_remote_mcp_tools(server_url: str, headers: dict[str, str], timeout: float, sse_read_timeout: float):
|
self,
|
||||||
with MCPClient(
|
server_url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
provider_entity: MCPProviderEntity,
|
||||||
|
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]],
|
||||||
|
):
|
||||||
|
"""Retrieve tools from remote MCP server"""
|
||||||
|
with MCPClientWithAuthRetry(
|
||||||
server_url,
|
server_url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=timeout,
|
timeout=provider_entity.timeout,
|
||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
|
provider_entity=provider_entity,
|
||||||
|
auth_callback=auth_callback,
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
tools = mcp_client.list_tools()
|
tools = mcp_client.list_tools()
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@staticmethod
|
def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]:
|
||||||
def _process_headers(headers: dict[str, str], tokens: OAuthTokens | None = None):
|
"""Process headers and add OAuth token if available"""
|
||||||
headers = headers or {}
|
headers = headers.copy() if headers else {}
|
||||||
if tokens:
|
if tokens:
|
||||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
@staticmethod
|
def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
"""Get MCP provider by ID"""
|
||||||
res = (
|
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||||
db.session.query(MCPToolProvider)
|
provider = self._session.scalar(stmt)
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
if not provider:
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not res:
|
|
||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
return res
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||||
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
"""Get MCP provider by server identifier"""
|
||||||
res = (
|
stmt = select(MCPToolProvider).where(
|
||||||
db.session.query(MCPToolProvider)
|
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
if not res:
|
provider = self._session.scalar(stmt)
|
||||||
|
if not provider:
|
||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
return res
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
def create_provider(
|
||||||
def create_mcp_provider(
|
self,
|
||||||
|
*,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
server_url: str,
|
server_url: str,
|
||||||
@ -111,19 +121,20 @@ class MCPToolManageService:
|
|||||||
sse_read_timeout: float,
|
sse_read_timeout: float,
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
|
"""Create a new MCP provider"""
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
existing_provider = (
|
|
||||||
db.session.query(MCPToolProvider)
|
# Check for existing provider
|
||||||
.where(
|
stmt = select(MCPToolProvider).where(
|
||||||
MCPToolProvider.tenant_id == tenant_id,
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
or_(
|
or_(
|
||||||
MCPToolProvider.name == name,
|
MCPToolProvider.name == name,
|
||||||
MCPToolProvider.server_url_hash == server_url_hash,
|
MCPToolProvider.server_url_hash == server_url_hash,
|
||||||
MCPToolProvider.server_identifier == server_identifier,
|
MCPToolProvider.server_identifier == server_identifier,
|
||||||
),
|
),
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
existing_provider = self._session.scalar(stmt)
|
||||||
|
|
||||||
if existing_provider:
|
if existing_provider:
|
||||||
if existing_provider.name == name:
|
if existing_provider.name == name:
|
||||||
raise ValueError(f"MCP tool {name} already exists")
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
@ -131,13 +142,17 @@ class MCPToolManageService:
|
|||||||
raise ValueError(f"MCP tool {server_url} already exists")
|
raise ValueError(f"MCP tool {server_url} already exists")
|
||||||
if existing_provider.server_identifier == server_identifier:
|
if existing_provider.server_identifier == server_identifier:
|
||||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||||
|
|
||||||
|
# Encrypt server URL
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||||
|
|
||||||
# Encrypt headers
|
# Encrypt headers
|
||||||
encrypted_headers = None
|
encrypted_headers = None
|
||||||
if headers:
|
if headers:
|
||||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
|
||||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||||
|
|
||||||
|
# Create provider
|
||||||
mcp_tool = MCPToolProvider(
|
mcp_tool = MCPToolProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=name,
|
name=name,
|
||||||
@ -152,91 +167,68 @@ class MCPToolManageService:
|
|||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=sse_read_timeout,
|
||||||
encrypted_headers=encrypted_headers,
|
encrypted_headers=encrypted_headers,
|
||||||
)
|
)
|
||||||
db.session.add(mcp_tool)
|
|
||||||
db.session.commit()
|
self._session.add(mcp_tool)
|
||||||
|
self._session.commit()
|
||||||
|
|
||||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||||
|
|
||||||
@staticmethod
|
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||||
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
"""List all MCP providers for a tenant"""
|
||||||
mcp_providers = (
|
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||||
db.session.query(MCPToolProvider)
|
|
||||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
mcp_providers = self._session.scalars(stmt).all()
|
||||||
.order_by(MCPToolProvider.name)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
|
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
|
||||||
for mcp_provider in mcp_providers
|
for provider in mcp_providers
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
"""List tools from remote MCP server"""
|
||||||
from core.mcp.auth.auth_flow import auth
|
from core.mcp.auth.auth_flow import auth
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
|
||||||
|
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
# Load provider and convert to entity
|
||||||
server_url = mcp_provider.decrypted_server_url
|
db_provider = self.get_provider_by_id(provider_id, tenant_id)
|
||||||
authed = mcp_provider.authed
|
provider_entity = db_provider.to_entity()
|
||||||
headers = mcp_provider.decrypted_headers
|
|
||||||
timeout = mcp_provider.timeout
|
|
||||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
|
||||||
|
|
||||||
# Handle authentication headers if authed
|
# Handle authentication headers if authed
|
||||||
if not authed:
|
if not provider_entity.authed:
|
||||||
raise ValueError("Please auth the tool first")
|
raise ValueError("Please auth the tool first")
|
||||||
|
|
||||||
provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
tokens = provider_entity.retrieve_tokens()
|
||||||
tokens = provider.tokens()
|
headers = self._process_headers(provider_entity.headers, tokens)
|
||||||
headers = cls._process_headers(headers, tokens)
|
server_url = provider_entity.decrypt_server_url()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tools = cls._retrieve_remote_mcp_tools(server_url, headers, timeout, sse_read_timeout)
|
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth)
|
||||||
except MCPAuthError:
|
|
||||||
try:
|
|
||||||
auth(provider, server_url, None, None, False)
|
|
||||||
tokens = provider.tokens()
|
|
||||||
re_authed_headers = cls._process_headers(headers, tokens)
|
|
||||||
tools = cls._retrieve_remote_mcp_tools(server_url, re_authed_headers, timeout, sse_read_timeout)
|
|
||||||
except Exception:
|
|
||||||
raise ValueError("Please auth the tool first")
|
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||||
|
|
||||||
try:
|
# Update database record with new tools
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
db_provider.authed = True
|
||||||
mcp_provider.authed = True
|
db_provider.updated_at = datetime.now()
|
||||||
mcp_provider.updated_at = datetime.now()
|
self._session.commit()
|
||||||
db.session.commit()
|
|
||||||
except Exception:
|
|
||||||
db.session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
user = mcp_provider.load_user()
|
# Create API response using entity
|
||||||
return ToolProviderApiEntity(
|
user = db_provider.load_user()
|
||||||
id=mcp_provider.id,
|
response = provider_entity.to_api_response(
|
||||||
name=mcp_provider.name,
|
user_name=user.name if user else None,
|
||||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
|
||||||
type=ToolProviderType.MCP,
|
|
||||||
icon=mcp_provider.icon,
|
|
||||||
author=user.name if user else "Anonymous",
|
|
||||||
server_url=mcp_provider.masked_server_url,
|
|
||||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
|
||||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
|
||||||
plugin_unique_identifier=mcp_provider.server_identifier,
|
|
||||||
)
|
)
|
||||||
|
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||||
|
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||||
|
|
||||||
@classmethod
|
return ToolProviderApiEntity(**response)
|
||||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
|
||||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
||||||
|
|
||||||
db.session.delete(mcp_tool)
|
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||||
db.session.commit()
|
"""Delete an MCP provider"""
|
||||||
|
mcp_tool = self.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
self._session.delete(mcp_tool)
|
||||||
|
self._session.commit()
|
||||||
|
|
||||||
@classmethod
|
def update_provider(
|
||||||
def update_mcp_provider(
|
self,
|
||||||
cls,
|
*,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
@ -248,21 +240,27 @@ class MCPToolManageService:
|
|||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
sse_read_timeout: float | None = None,
|
sse_read_timeout: float | None = None,
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
):
|
) -> None:
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
"""Update an MCP provider"""
|
||||||
|
mcp_provider = self.get_provider_by_id(provider_id, tenant_id)
|
||||||
|
|
||||||
reconnect_result = None
|
reconnect_result = None
|
||||||
encrypted_server_url = None
|
encrypted_server_url = None
|
||||||
server_url_hash = None
|
server_url_hash = None
|
||||||
|
|
||||||
|
# Handle server URL update
|
||||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
|
|
||||||
if server_url_hash != mcp_provider.server_url_hash:
|
if server_url_hash != mcp_provider.server_url_hash:
|
||||||
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
|
reconnect_result = self._reconnect_provider(
|
||||||
|
server_url=server_url,
|
||||||
|
provider=mcp_provider,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Update basic fields
|
||||||
mcp_provider.updated_at = datetime.now()
|
mcp_provider.updated_at = datetime.now()
|
||||||
mcp_provider.name = name
|
mcp_provider.name = name
|
||||||
mcp_provider.icon = (
|
mcp_provider.icon = (
|
||||||
@ -270,6 +268,7 @@ class MCPToolManageService:
|
|||||||
)
|
)
|
||||||
mcp_provider.server_identifier = server_identifier
|
mcp_provider.server_identifier = server_identifier
|
||||||
|
|
||||||
|
# Update server URL if changed
|
||||||
if encrypted_server_url is not None and server_url_hash is not None:
|
if encrypted_server_url is not None and server_url_hash is not None:
|
||||||
mcp_provider.server_url = encrypted_server_url
|
mcp_provider.server_url = encrypted_server_url
|
||||||
mcp_provider.server_url_hash = server_url_hash
|
mcp_provider.server_url_hash = server_url_hash
|
||||||
@ -279,6 +278,7 @@ class MCPToolManageService:
|
|||||||
mcp_provider.tools = reconnect_result["tools"]
|
mcp_provider.tools = reconnect_result["tools"]
|
||||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||||
|
|
||||||
|
# Update optional fields
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
mcp_provider.timeout = timeout
|
mcp_provider.timeout = timeout
|
||||||
if sse_read_timeout is not None:
|
if sse_read_timeout is not None:
|
||||||
@ -286,13 +286,15 @@ class MCPToolManageService:
|
|||||||
if headers is not None:
|
if headers is not None:
|
||||||
# Encrypt headers
|
# Encrypt headers
|
||||||
if headers:
|
if headers:
|
||||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
|
||||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||||
else:
|
else:
|
||||||
mcp_provider.encrypted_headers = None
|
mcp_provider.encrypted_headers = None
|
||||||
db.session.commit()
|
|
||||||
|
self._session.commit()
|
||||||
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.session.rollback()
|
self._session.rollback()
|
||||||
error_msg = str(e.orig)
|
error_msg = str(e.orig)
|
||||||
if "unique_mcp_provider_name" in error_msg:
|
if "unique_mcp_provider_name" in error_msg:
|
||||||
raise ValueError(f"MCP tool {name} already exists")
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
@ -302,54 +304,55 @@ class MCPToolManageService:
|
|||||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
db.session.rollback()
|
self._session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@classmethod
|
def update_provider_credentials(
|
||||||
def update_mcp_provider_credentials(
|
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
) -> None:
|
||||||
):
|
"""Update provider credentials"""
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
|
|
||||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
provider_controller = MCPToolProviderController.from_db(provider)
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
tenant_id=mcp_provider.tenant_id,
|
tenant_id=provider.tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
provider_config_cache=NoOpProviderCredentialCache(),
|
provider_config_cache=NoOpProviderCredentialCache(),
|
||||||
)
|
)
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||||
mcp_provider.updated_at = datetime.now()
|
provider.updated_at = datetime.now()
|
||||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
|
||||||
mcp_provider.authed = authed
|
provider.authed = authed
|
||||||
if not authed:
|
if not authed:
|
||||||
mcp_provider.tools = "[]"
|
provider.tools = "[]"
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
@classmethod
|
self._session.commit()
|
||||||
def clear_mcp_provider_credentials(
|
|
||||||
cls,
|
|
||||||
mcp_provider: MCPToolProvider,
|
|
||||||
):
|
|
||||||
mcp_provider.tools = "[]"
|
|
||||||
mcp_provider.encrypted_credentials = "{}"
|
|
||||||
mcp_provider.updated_at = datetime.now()
|
|
||||||
mcp_provider.authed = False
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
@classmethod
|
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
|
||||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str) -> dict[str, Any]:
|
"""Clear provider credentials"""
|
||||||
# Get the existing provider to access headers and timeout settings
|
provider.tools = "[]"
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
provider.encrypted_credentials = "{}"
|
||||||
headers = mcp_provider.decrypted_headers
|
provider.updated_at = datetime.now()
|
||||||
timeout = mcp_provider.timeout
|
provider.authed = False
|
||||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
self._session.commit()
|
||||||
|
|
||||||
|
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
|
||||||
|
"""Attempt to reconnect to MCP provider with new server URL"""
|
||||||
|
from core.mcp.auth.auth_flow import auth
|
||||||
|
|
||||||
|
provider_entity = provider.to_entity()
|
||||||
|
headers = provider_entity.headers
|
||||||
|
timeout = provider_entity.timeout
|
||||||
|
sse_read_timeout = provider_entity.sse_read_timeout
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with MCPClient(
|
with MCPClientWithAuthRetry(
|
||||||
server_url,
|
server_url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=sse_read_timeout,
|
||||||
|
provider_entity=provider_entity,
|
||||||
|
auth_callback=auth,
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
tools = mcp_client.list_tools()
|
tools = mcp_client.list_tools()
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -221,27 +221,20 @@ class ToolTransformService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
||||||
|
# Convert to entity and use its API response method
|
||||||
|
provider_entity = db_provider.to_entity()
|
||||||
user = db_provider.load_user()
|
user = db_provider.load_user()
|
||||||
return ToolProviderApiEntity(
|
|
||||||
id=db_provider.server_identifier if not for_list else db_provider.id,
|
response = provider_entity.to_api_response(user_name=user.name if user else None)
|
||||||
author=user.name if user else "Anonymous",
|
|
||||||
name=db_provider.name,
|
# Add additional fields specific to the transform
|
||||||
icon=db_provider.provider_icon,
|
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||||
type=ToolProviderType.MCP,
|
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(
|
||||||
is_team_authorization=db_provider.authed,
|
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||||
server_url=db_provider.masked_server_url,
|
|
||||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
|
||||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
|
||||||
),
|
|
||||||
updated_at=int(db_provider.updated_at.timestamp()),
|
|
||||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
|
||||||
server_identifier=db_provider.server_identifier,
|
|
||||||
timeout=db_provider.timeout,
|
|
||||||
sse_read_timeout=db_provider.sse_read_timeout,
|
|
||||||
masked_headers=db_provider.masked_headers,
|
|
||||||
original_headers=db_provider.decrypted_headers,
|
|
||||||
)
|
)
|
||||||
|
response["server_identifier"] = db_provider.server_identifier
|
||||||
|
|
||||||
|
return ToolProviderApiEntity(**response)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
||||||
@ -403,7 +396,7 @@ class ToolTransformService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
|
||||||
"""
|
"""
|
||||||
Convert MCP JSON schema to tool parameters
|
Convert MCP JSON schema to tool parameters
|
||||||
|
|
||||||
@ -412,7 +405,7 @@ class ToolTransformService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def create_parameter(
|
def create_parameter(
|
||||||
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
|
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||||
) -> ToolParameter:
|
) -> ToolParameter:
|
||||||
"""Create a ToolParameter instance with given attributes"""
|
"""Create a ToolParameter instance with given attributes"""
|
||||||
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
||||||
@ -427,7 +420,9 @@ class ToolTransformService:
|
|||||||
**input_schema_dict,
|
**input_schema_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
def process_properties(
|
||||||
|
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
|
||||||
|
) -> list[ToolParameter]:
|
||||||
"""Process properties recursively"""
|
"""Process properties recursively"""
|
||||||
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
||||||
COMPLEX_TYPES = ["array", "object"]
|
COMPLEX_TYPES = ["array", "object"]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user