mirror of https://github.com/langgenius/dify.git
feat: add mcp tool display directly (#30019)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b892906d71
commit
bdd8a35b9d
|
|
@ -1,4 +1,5 @@
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from flask import make_response, redirect, request, send_file
|
from flask import make_response, redirect, request, send_file
|
||||||
|
|
@ -17,6 +18,7 @@ from controllers.console.wraps import (
|
||||||
is_admin_or_owner_required,
|
is_admin_or_owner_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||||
|
|
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_valid_url(url: str) -> bool:
|
def is_valid_url(url: str) -> bool:
|
||||||
if not url:
|
if not url:
|
||||||
|
|
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
|
||||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
|
|
||||||
# Create provider in transaction
|
# 1) Create provider in a short transaction (no network I/O inside)
|
||||||
with Session(db.engine) as session, session.begin():
|
with session_factory.create_session() as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
result = service.create_provider(
|
result = service.create_provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
|
||||||
authentication=authentication,
|
authentication=authentication,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||||
|
# Perform network I/O outside any DB session to avoid holding locks.
|
||||||
|
try:
|
||||||
|
reconnect = MCPToolManageService.reconnect_with_url(
|
||||||
|
server_url=args["server_url"],
|
||||||
|
headers=args.get("headers") or {},
|
||||||
|
timeout=configuration.timeout,
|
||||||
|
sse_read_timeout=configuration.sse_read_timeout,
|
||||||
|
)
|
||||||
|
# Update just-created provider with authed/tools in a new short transaction
|
||||||
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
service = MCPToolManageService(session=session)
|
||||||
|
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
|
||||||
|
db_provider.authed = reconnect.authed
|
||||||
|
db_provider.tools = reconnect.tools
|
||||||
|
|
||||||
|
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
|
||||||
|
except Exception:
|
||||||
|
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
|
||||||
|
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
|
||||||
|
|
||||||
|
# Final cache invalidation to ensure list views are up to date
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
|
|
||||||
return jsonable_encoder(result)
|
return jsonable_encoder(result)
|
||||||
|
|
|
||||||
|
|
@ -319,8 +319,14 @@ class MCPToolManageService:
|
||||||
except MCPError as e:
|
except MCPError as e:
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||||
|
|
||||||
# Update database with retrieved tools
|
# Update database with retrieved tools (ensure description is a non-null string)
|
||||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
tools_payload = []
|
||||||
|
for tool in tools:
|
||||||
|
data = tool.model_dump()
|
||||||
|
if data.get("description") is None:
|
||||||
|
data["description"] = ""
|
||||||
|
tools_payload.append(data)
|
||||||
|
db_provider.tools = json.dumps(tools_payload)
|
||||||
db_provider.authed = True
|
db_provider.authed = True
|
||||||
db_provider.updated_at = datetime.now()
|
db_provider.updated_at = datetime.now()
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|
@ -620,6 +626,21 @@ class MCPToolManageService:
|
||||||
server_url_hash=new_server_url_hash,
|
server_url_hash=new_server_url_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reconnect_with_url(
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
timeout: float | None,
|
||||||
|
sse_read_timeout: float | None,
|
||||||
|
) -> ReconnectResult:
|
||||||
|
return MCPToolManageService._reconnect_with_url(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
sse_read_timeout=sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reconnect_with_url(
|
def _reconnect_with_url(
|
||||||
*,
|
*,
|
||||||
|
|
@ -642,9 +663,16 @@ class MCPToolManageService:
|
||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=sse_read_timeout,
|
||||||
) as mcp_client:
|
) as mcp_client:
|
||||||
tools = mcp_client.list_tools()
|
tools = mcp_client.list_tools()
|
||||||
|
# Ensure tool descriptions are non-null in payload
|
||||||
|
tools_payload = []
|
||||||
|
for t in tools:
|
||||||
|
d = t.model_dump()
|
||||||
|
if d.get("description") is None:
|
||||||
|
d["description"] = ""
|
||||||
|
tools_payload.append(d)
|
||||||
return ReconnectResult(
|
return ReconnectResult(
|
||||||
authed=True,
|
authed=True,
|
||||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
tools=json.dumps(tools_payload),
|
||||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||||
)
|
)
|
||||||
except MCPAuthError:
|
except MCPAuthError:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask_restx import Api
|
||||||
|
|
||||||
|
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||||
|
from core.db.session_factory import configure_session_factory
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||||
|
|
||||||
|
|
||||||
|
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||||
|
# They are intentionally no-ops because the test already patches the required
|
||||||
|
# behaviors explicitly via @patch and context managers below.
|
||||||
|
@pytest.fixture
|
||||||
|
def _mock_cache():
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def _mock_user_tenant():
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||||
|
api = Api(app)
|
||||||
|
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||||
|
db.init_app(app)
|
||||||
|
# Configure session factory used by controller code
|
||||||
|
with app.app_context():
|
||||||
|
configure_session_factory(db.engine)
|
||||||
|
return app.test_client()
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||||
|
)
|
||||||
|
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
|
||||||
|
@patch("controllers.console.workspace.tool_providers.Session")
|
||||||
|
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||||
|
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||||
|
def test_create_mcp_provider_populates_tools(
|
||||||
|
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
|
||||||
|
):
|
||||||
|
# Arrange: reconnect returns tools immediately
|
||||||
|
mock_reconnect.return_value = ReconnectResult(
|
||||||
|
authed=True,
|
||||||
|
tools=json.dumps(
|
||||||
|
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
|
||||||
|
),
|
||||||
|
encrypted_credentials="{}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fake service.create_provider -> returns object with id for reload
|
||||||
|
svc = MagicMock()
|
||||||
|
create_result = MagicMock()
|
||||||
|
create_result.id = "provider-1"
|
||||||
|
svc.create_provider.return_value = create_result
|
||||||
|
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||||
|
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||||
|
# Patch MCPToolManageService constructed inside controller
|
||||||
|
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||||
|
payload = {
|
||||||
|
"server_url": "http://example.com/mcp",
|
||||||
|
"name": "demo",
|
||||||
|
"icon": "😀",
|
||||||
|
"icon_type": "emoji",
|
||||||
|
"icon_background": "#000",
|
||||||
|
"server_identifier": "demo-sid",
|
||||||
|
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||||
|
"headers": {},
|
||||||
|
"authentication": {},
|
||||||
|
}
|
||||||
|
# Act
|
||||||
|
with (
|
||||||
|
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||||
|
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||||
|
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||||
|
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||||
|
patch(
|
||||||
|
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||||
|
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||||
|
),
|
||||||
|
):
|
||||||
|
resp = client.post(
|
||||||
|
"/console/api/workspaces/current/tool-provider/mcp",
|
||||||
|
data=json.dumps(payload),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.get_json()
|
||||||
|
assert body.get("id") == "provider-1"
|
||||||
|
# 若 transform 后包含 tools 字段,确保非空
|
||||||
|
assert isinstance(body.get("tools"), list)
|
||||||
|
assert body["tools"]
|
||||||
Loading…
Reference in New Issue