feat: add mcp tool display directly (#30019)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei 2025-12-26 10:41:10 +08:00 committed by GitHub
parent b892906d71
commit bdd8a35b9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 162 additions and 6 deletions

View File

@ -1,4 +1,5 @@
import io import io
import logging
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file from flask import make_response, redirect, request, send_file
@ -17,6 +18,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required, is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_flow import auth, handle_callback
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService from services.tools.workflow_tools_manage_service import WorkflowToolManageService
logger = logging.getLogger(__name__)
def is_valid_url(url: str) -> bool: def is_valid_url(url: str) -> bool:
if not url: if not url:
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"]) configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider in transaction # 1) Create provider in a short transaction (no network I/O inside)
with Session(db.engine) as session, session.begin(): with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
result = service.create_provider( result = service.create_provider(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
authentication=authentication, authentication=authentication,
) )
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations # 2) Try to fetch tools immediately after creation so they appear without a second save.
# Perform network I/O outside any DB session to avoid holding locks.
try:
reconnect = MCPToolManageService.reconnect_with_url(
server_url=args["server_url"],
headers=args.get("headers") or {},
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
)
# Update just-created provider with authed/tools in a new short transaction
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
db_provider.authed = reconnect.authed
db_provider.tools = reconnect.tools
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
except Exception:
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
# Final cache invalidation to ensure list views are up to date
ToolProviderListCache.invalidate_cache(tenant_id) ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result) return jsonable_encoder(result)

View File

@ -319,8 +319,14 @@ class MCPToolManageService:
except MCPError as e: except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}") raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools # Update database with retrieved tools (ensure description is a non-null string)
db_provider.tools = json.dumps([tool.model_dump() for tool in tools]) tools_payload = []
for tool in tools:
data = tool.model_dump()
if data.get("description") is None:
data["description"] = ""
tools_payload.append(data)
db_provider.tools = json.dumps(tools_payload)
db_provider.authed = True db_provider.authed = True
db_provider.updated_at = datetime.now() db_provider.updated_at = datetime.now()
self._session.flush() self._session.flush()
@ -620,6 +626,21 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash, server_url_hash=new_server_url_hash,
) )
@staticmethod
def reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
return MCPToolManageService._reconnect_with_url(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
@staticmethod @staticmethod
def _reconnect_with_url( def _reconnect_with_url(
*, *,
@ -642,9 +663,16 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout, sse_read_timeout=sse_read_timeout,
) as mcp_client: ) as mcp_client:
tools = mcp_client.list_tools() tools = mcp_client.list_tools()
# Ensure tool descriptions are non-null in payload
tools_payload = []
for t in tools:
d = t.model_dump()
if d.get("description") is None:
d["description"] = ""
tools_payload.append(d)
return ReconnectResult( return ReconnectResult(
authed=True, authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]), tools=json.dumps(tools_payload),
encrypted_credentials=EMPTY_CREDENTIALS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON,
) )
except MCPAuthError: except MCPAuthError:

View File

@ -0,0 +1,103 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
@pytest.fixture
def _mock_cache():
return
@pytest.fixture
def _mock_user_tenant():
return
@pytest.fixture
def client():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
api = Api(app)
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
db.init_app(app)
# Configure session factory used by controller code
with app.app_context():
configure_session_factory(db.engine)
return app.test_client()
@patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
)
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
):
# Arrange: reconnect returns tools immediately
mock_reconnect.return_value = ReconnectResult(
authed=True,
tools=json.dumps(
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
),
encrypted_credentials="{}",
)
# Fake service.create_provider -> returns object with id for reload
svc = MagicMock()
create_result = MagicMock()
create_result.id = "provider-1"
svc.create_provider.return_value = create_result
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
mock_session.return_value.__enter__.return_value = MagicMock()
# Patch MCPToolManageService constructed inside controller
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
payload = {
"server_url": "http://example.com/mcp",
"name": "demo",
"icon": "😀",
"icon_type": "emoji",
"icon_background": "#000",
"server_identifier": "demo-sid",
"configuration": {"timeout": 5, "sse_read_timeout": 30},
"headers": {},
"authentication": {},
}
# Act
with (
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
patch(
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
),
):
resp = client.post(
"/console/api/workspaces/current/tool-provider/mcp",
data=json.dumps(payload),
content_type="application/json",
)
# Assert
assert resp.status_code == 200
body = resp.get_json()
assert body.get("id") == "provider-1"
# 若 transform 后包含 tools 字段,确保非空
assert isinstance(body.get("tools"), list)
assert body["tools"]