mirror of https://github.com/langgenius/dify.git
feat: add mcp tool display directly (#30019)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b892906d71
commit
bdd8a35b9d
|
|
@ -1,4 +1,5 @@
|
|||
import io
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
|
|
@ -17,6 +18,7 @@ from controllers.console.wraps import (
|
|||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
|
|
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
|
|||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
|
|
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
# 1) Create provider in a short transaction (no network I/O inside)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication=authentication,
|
||||
)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
# Perform network I/O outside any DB session to avoid holding locks.
|
||||
try:
|
||||
reconnect = MCPToolManageService.reconnect_with_url(
|
||||
server_url=args["server_url"],
|
||||
headers=args.get("headers") or {},
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
)
|
||||
# Update just-created provider with authed/tools in a new short transaction
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
|
||||
db_provider.authed = reconnect.authed
|
||||
db_provider.tools = reconnect.tools
|
||||
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
|
||||
except Exception:
|
||||
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
|
||||
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
|
||||
|
||||
# Final cache invalidation to ensure list views are up to date
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
|
|
|||
|
|
@ -319,8 +319,14 @@ class MCPToolManageService:
|
|||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
# Update database with retrieved tools (ensure description is a non-null string)
|
||||
tools_payload = []
|
||||
for tool in tools:
|
||||
data = tool.model_dump()
|
||||
if data.get("description") is None:
|
||||
data["description"] = ""
|
||||
tools_payload.append(data)
|
||||
db_provider.tools = json.dumps(tools_payload)
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.flush()
|
||||
|
|
@ -620,6 +626,21 @@ class MCPToolManageService:
|
|||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def reconnect_with_url(
|
||||
*,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
timeout: float | None,
|
||||
sse_read_timeout: float | None,
|
||||
) -> ReconnectResult:
|
||||
return MCPToolManageService._reconnect_with_url(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reconnect_with_url(
|
||||
*,
|
||||
|
|
@ -642,9 +663,16 @@ class MCPToolManageService:
|
|||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
# Ensure tool descriptions are non-null in payload
|
||||
tools_payload = []
|
||||
for t in tools:
|
||||
d = t.model_dump()
|
||||
if d.get("description") is None:
|
||||
d["description"] = ""
|
||||
tools_payload.append(d)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
tools=json.dumps(tools_payload),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_user_tenant():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||
)
|
||||
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
|
||||
@patch("controllers.console.workspace.tool_providers.Session")
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(
|
||||
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
|
||||
):
|
||||
# Arrange: reconnect returns tools immediately
|
||||
mock_reconnect.return_value = ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps(
|
||||
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
|
||||
),
|
||||
encrypted_credentials="{}",
|
||||
)
|
||||
|
||||
# Fake service.create_provider -> returns object with id for reload
|
||||
svc = MagicMock()
|
||||
create_result = MagicMock()
|
||||
create_result.id = "provider-1"
|
||||
svc.create_provider.return_value = create_result
|
||||
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
# Patch MCPToolManageService constructed inside controller
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||
payload = {
|
||||
"server_url": "http://example.com/mcp",
|
||||
"name": "demo",
|
||||
"icon": "😀",
|
||||
"icon_type": "emoji",
|
||||
"icon_background": "#000",
|
||||
"server_identifier": "demo-sid",
|
||||
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||
"headers": {},
|
||||
"authentication": {},
|
||||
}
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
),
|
||||
):
|
||||
resp = client.post(
|
||||
"/console/api/workspaces/current/tool-provider/mcp",
|
||||
data=json.dumps(payload),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
Loading…
Reference in New Issue