mirror of https://github.com/langgenius/dify.git
fix: add RFC 9728 compliant well-known URL discovery with path insertion fallback (#29960)
This commit is contained in:
parent
7b60ff3d2d
commit
7501360663
|
|
@ -18,6 +18,7 @@ from controllers.console.wraps import (
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||||
|
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
|
@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
|
||||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
|
|
||||||
# Create provider
|
# Create provider in transaction
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
result = service.create_provider(
|
result = service.create_provider(
|
||||||
|
|
@ -960,6 +961,10 @@ class ToolProviderMCPApi(Resource):
|
||||||
configuration=configuration,
|
configuration=configuration,
|
||||||
authentication=authentication,
|
authentication=authentication,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||||
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
|
|
||||||
return jsonable_encoder(result)
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
@console_ns.expect(parser_mcp_put)
|
@console_ns.expect(parser_mcp_put)
|
||||||
|
|
@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||||
validation_result = None
|
validation_data = None
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
validation_result = service.validate_server_url_change(
|
validation_data = service.get_provider_for_url_validation(
|
||||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
tenant_id=current_tenant_id, provider_id=args["provider_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# No need to check for errors here, exceptions will be raised directly
|
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
|
||||||
|
# This prevents holding database locks during potentially slow network operations
|
||||||
|
validation_result = MCPToolManageService.validate_server_url_standalone(
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
new_server_url=args["server_url"],
|
||||||
|
validation_data=validation_data,
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: Perform database update in a transaction
|
# Step 3: Perform database update in a transaction
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.update_provider(
|
service.update_provider(
|
||||||
|
|
@ -999,6 +1010,10 @@ class ToolProviderMCPApi(Resource):
|
||||||
authentication=authentication,
|
authentication=authentication,
|
||||||
validation_result=validation_result,
|
validation_result=validation_result,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||||
|
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@console_ns.expect(parser_mcp_delete)
|
@console_ns.expect(parser_mcp_delete)
|
||||||
|
|
@ -1012,6 +1027,10 @@ class ToolProviderMCPApi(Resource):
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||||
|
|
||||||
|
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||||
|
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
|
||||||
"""
|
"""
|
||||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||||
|
|
||||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
|
||||||
|
Priority order:
|
||||||
|
1. URL from WWW-Authenticate header (if provided)
|
||||||
|
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
|
||||||
|
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
|
||||||
"""
|
"""
|
||||||
urls = []
|
urls = []
|
||||||
|
|
||||||
|
|
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
|
||||||
# Fallback: construct from server URL
|
# Fallback: construct from server URL
|
||||||
parsed = urlparse(server_url)
|
parsed = urlparse(server_url)
|
||||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
path = parsed.path.rstrip("/")
|
||||||
if fallback_url not in urls:
|
|
||||||
urls.append(fallback_url)
|
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
|
||||||
|
if path:
|
||||||
|
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
|
||||||
|
if path_url not in urls:
|
||||||
|
urls.append(path_url)
|
||||||
|
|
||||||
|
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
|
||||||
|
root_url = f"{base_url}/.well-known/oauth-protected-resource"
|
||||||
|
if root_url not in urls:
|
||||||
|
urls.append(root_url)
|
||||||
|
|
||||||
return urls
|
return urls
|
||||||
|
|
||||||
|
|
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
|
||||||
|
|
||||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||||
|
|
||||||
Per RFC 8414 section 3:
|
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
|
||||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
|
||||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
|
||||||
|
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
|
||||||
Example:
|
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
|
||||||
- issuer: https://example.com/oauth
|
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
|
||||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
|
||||||
"""
|
"""
|
||||||
urls = []
|
urls = []
|
||||||
base_url = auth_server_url or server_url
|
base_url = auth_server_url or server_url
|
||||||
|
|
||||||
parsed = urlparse(base_url)
|
parsed = urlparse(base_url)
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
path = parsed.path.rstrip("/")
|
||||||
|
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
|
||||||
|
urls.append(f"{base}/.well-known/oauth-authorization-server")
|
||||||
|
|
||||||
# Try OpenID Connect discovery first (more common)
|
# OpenID Connect Discovery at root
|
||||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
urls.append(f"{base}/.well-known/openid-configuration")
|
||||||
|
|
||||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
|
||||||
# Include the path component if present in the issuer URL
|
|
||||||
if path:
|
if path:
|
||||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
# OpenID Connect Discovery with path insertion
|
||||||
else:
|
urls.append(f"{base}/.well-known/openid-configuration{path}")
|
||||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
|
||||||
|
# OpenID Connect Discovery path appending
|
||||||
|
urls.append(f"{base}{path}/.well-known/openid-configuration")
|
||||||
|
|
||||||
|
# OAuth 2.0 Authorization Server Metadata with path insertion
|
||||||
|
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||||
|
|
||||||
return urls
|
return urls
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ class MCPClient:
|
||||||
try:
|
try:
|
||||||
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
||||||
self.connect_server(sse_client, "sse")
|
self.connect_server(sse_client, "sse")
|
||||||
except MCPConnectionError:
|
except (MCPConnectionError, ValueError):
|
||||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||||
self.connect_server(streamablehttp_client, "mcp")
|
self.connect_server(streamablehttp_client, "mcp")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
|
||||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
|
||||||
from core.mcp.auth.auth_flow import auth
|
from core.mcp.auth.auth_flow import auth
|
||||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
|
|
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
|
||||||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderUrlValidationData(BaseModel):
|
||||||
|
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
|
||||||
|
|
||||||
|
current_server_url_hash: str
|
||||||
|
headers: dict[str, str]
|
||||||
|
timeout: float | None
|
||||||
|
sse_read_timeout: float | None
|
||||||
|
|
||||||
|
|
||||||
class MCPToolManageService:
|
class MCPToolManageService:
|
||||||
"""Service class for managing MCP tools and providers."""
|
"""Service class for managing MCP tools and providers."""
|
||||||
|
|
||||||
|
|
@ -166,9 +174,6 @@ class MCPToolManageService:
|
||||||
self._session.add(mcp_tool)
|
self._session.add(mcp_tool)
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||||
return mcp_providers
|
return mcp_providers
|
||||||
|
|
||||||
|
|
@ -192,7 +197,7 @@ class MCPToolManageService:
|
||||||
Update an MCP provider.
|
Update an MCP provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
validation_result: Pre-validation result from validate_server_url_change.
|
validation_result: Pre-validation result from validate_server_url_standalone.
|
||||||
If provided and contains reconnect_result, it will be used
|
If provided and contains reconnect_result, it will be used
|
||||||
instead of performing network operations.
|
instead of performing network operations.
|
||||||
"""
|
"""
|
||||||
|
|
@ -251,8 +256,6 @@ class MCPToolManageService:
|
||||||
# Flush changes to database
|
# Flush changes to database
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||||
|
|
||||||
|
|
@ -261,9 +264,6 @@ class MCPToolManageService:
|
||||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
self._session.delete(mcp_tool)
|
self._session.delete(mcp_tool)
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
def list_providers(
|
def list_providers(
|
||||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||||
) -> list[ToolProviderApiEntity]:
|
) -> list[ToolProviderApiEntity]:
|
||||||
|
|
@ -546,30 +546,39 @@ class MCPToolManageService:
|
||||||
)
|
)
|
||||||
return self.execute_auth_actions(auth_result)
|
return self.execute_auth_actions(auth_result)
|
||||||
|
|
||||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
|
||||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
"""
|
||||||
|
Get provider data required for URL validation.
|
||||||
|
This method performs database read and should be called within a session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProviderUrlValidationData: Data needed for standalone URL validation
|
||||||
|
"""
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
provider_entity = provider.to_entity()
|
provider_entity = provider.to_entity()
|
||||||
headers = provider_entity.headers
|
return ProviderUrlValidationData(
|
||||||
|
current_server_url_hash=provider.server_url_hash,
|
||||||
try:
|
headers=provider_entity.headers,
|
||||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
timeout=provider_entity.timeout,
|
||||||
return ReconnectResult(
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
authed=True,
|
|
||||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
|
||||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
|
||||||
)
|
)
|
||||||
except MCPAuthError:
|
|
||||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
|
||||||
except MCPError as e:
|
|
||||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
|
||||||
|
|
||||||
def validate_server_url_change(
|
@staticmethod
|
||||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
def validate_server_url_standalone(
|
||||||
|
*,
|
||||||
|
tenant_id: str,
|
||||||
|
new_server_url: str,
|
||||||
|
validation_data: ProviderUrlValidationData,
|
||||||
) -> ServerUrlValidationResult:
|
) -> ServerUrlValidationResult:
|
||||||
"""
|
"""
|
||||||
Validate server URL change by attempting to connect to the new server.
|
Validate server URL change by attempting to connect to the new server.
|
||||||
This method should be called BEFORE update_provider to perform network operations
|
This method performs network operations and MUST be called OUTSIDE of any database session
|
||||||
outside of the database transaction.
|
to avoid holding locks during network I/O.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID for encryption
|
||||||
|
new_server_url: The new server URL to validate
|
||||||
|
validation_data: Provider data obtained from get_provider_for_url_validation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||||
|
|
@ -579,25 +588,30 @@ class MCPToolManageService:
|
||||||
return ServerUrlValidationResult(needs_validation=False)
|
return ServerUrlValidationResult(needs_validation=False)
|
||||||
|
|
||||||
# Validate URL format
|
# Validate URL format
|
||||||
if not self._is_valid_url(new_server_url):
|
parsed = urlparse(new_server_url)
|
||||||
|
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
|
|
||||||
# Always encrypt and hash the URL
|
# Always encrypt and hash the URL
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||||
|
|
||||||
# Get current provider
|
|
||||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
|
||||||
|
|
||||||
# Check if URL is actually different
|
# Check if URL is actually different
|
||||||
if new_server_url_hash == provider.server_url_hash:
|
if new_server_url_hash == validation_data.current_server_url_hash:
|
||||||
# URL hasn't changed, but still return the encrypted data
|
# URL hasn't changed, but still return the encrypted data
|
||||||
return ServerUrlValidationResult(
|
return ServerUrlValidationResult(
|
||||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
needs_validation=False,
|
||||||
|
encrypted_server_url=encrypted_server_url,
|
||||||
|
server_url_hash=new_server_url_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform validation by attempting to connect
|
# Perform network validation - this is the expensive operation that should be outside session
|
||||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
reconnect_result = MCPToolManageService._reconnect_with_url(
|
||||||
|
server_url=new_server_url,
|
||||||
|
headers=validation_data.headers,
|
||||||
|
timeout=validation_data.timeout,
|
||||||
|
sse_read_timeout=validation_data.sse_read_timeout,
|
||||||
|
)
|
||||||
return ServerUrlValidationResult(
|
return ServerUrlValidationResult(
|
||||||
needs_validation=True,
|
needs_validation=True,
|
||||||
validation_passed=True,
|
validation_passed=True,
|
||||||
|
|
@ -606,6 +620,38 @@ class MCPToolManageService:
|
||||||
server_url_hash=new_server_url_hash,
|
server_url_hash=new_server_url_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reconnect_with_url(
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
timeout: float | None,
|
||||||
|
sse_read_timeout: float | None,
|
||||||
|
) -> ReconnectResult:
|
||||||
|
"""
|
||||||
|
Attempt to connect to MCP server with given URL.
|
||||||
|
This is a static method that performs network I/O without database access.
|
||||||
|
"""
|
||||||
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
with MCPClient(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
sse_read_timeout=sse_read_timeout,
|
||||||
|
) as mcp_client:
|
||||||
|
tools = mcp_client.list_tools()
|
||||||
|
return ReconnectResult(
|
||||||
|
authed=True,
|
||||||
|
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||||
|
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||||
|
)
|
||||||
|
except MCPAuthError:
|
||||||
|
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||||
|
except MCPError as e:
|
||||||
|
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||||
|
|
||||||
def _build_tool_provider_response(
|
def _build_tool_provider_response(
|
||||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
|
|
|
||||||
|
|
@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
|
||||||
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||||
# Setup mock client
|
# Setup mock client
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.return_value = mock_tools
|
mock_client_instance.list_tools.return_value = mock_tools
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
from extensions.ext_database import db
|
result = MCPToolManageService._reconnect_with_url(
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
result = service._reconnect_provider(
|
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
provider=mcp_provider,
|
headers={"X-Test": "1"},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
|
|
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
|
||||||
assert tools_data[1]["name"] == "test_tool_2"
|
assert tools_data[1]["name"] == "test_tool_2"
|
||||||
|
|
||||||
# Verify mock interactions
|
# Verify mock interactions
|
||||||
provider_entity = mcp_provider.to_entity()
|
mock_mcp_client.assert_called_once_with(
|
||||||
mock_mcp_client.assert_called_once()
|
server_url="https://example.com/mcp",
|
||||||
|
headers={"X-Test": "1"},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
|
|
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock MCPClient to raise authentication error
|
# Mock MCPClient to raise authentication error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPAuthError
|
from core.mcp.error import MCPAuthError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
from extensions.ext_database import db
|
result = MCPToolManageService._reconnect_with_url(
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
result = service._reconnect_provider(
|
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
provider=mcp_provider,
|
headers={},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
|
|
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock MCPClient to raise connection error
|
# Mock MCPClient to raise connection error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPError
|
from core.mcp.error import MCPError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
||||||
service._reconnect_provider(
|
MCPToolManageService._reconnect_with_url(
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
provider=mcp_provider,
|
headers={"X-Test": "1"},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue