mirror of
https://github.com/langgenius/dify.git
synced 2026-03-14 13:51:33 +08:00
Merge remote-tracking branch 'upstream/release/e-1.12.1' into deploy/enterprise
This commit is contained in:
commit
42226fdd3b
@ -11,9 +11,35 @@ from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check.
|
||||
# Defined at module level to avoid per-request tuple construction.
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - features: billing/plan features (ProviderContextProvider)
|
||||
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/features",
|
||||
"/console/api/account/profile",
|
||||
"/console/api/workspaces/current",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@ -38,18 +64,12 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api")
|
||||
is_webapp_api = request.path.startswith("/api") and not is_console_api
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/") and not is_console_api
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
console_exempt_prefixes = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
is_exempt = any(request.path.startswith(p) for p in console_exempt_prefixes)
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
@ -57,10 +77,13 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
try:
|
||||
# Check license status with caching (10 min TTL)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in ["inactive", "expired", "lost"]:
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
# Cookie clearing is handled by register_external_error_handlers
|
||||
# in libs/external_api.py which detects the error code and calls
|
||||
# build_force_logout_cookie_headers(). Frontend then checks
|
||||
# code === 'unauthorized_and_force_logout' and calls location.reload().
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. "
|
||||
"Please contact your administrator."
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
|
||||
@ -19,6 +19,10 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTelemetryConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -101,10 +101,7 @@ class BaseRequest:
|
||||
# {"message": "..."}
|
||||
# {"detail": "..."}
|
||||
error_message = (
|
||||
error_data.get("message")
|
||||
or error_data.get("error")
|
||||
or error_data.get("detail")
|
||||
or error_message
|
||||
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
|
||||
)
|
||||
except Exception:
|
||||
# If JSON parsing fails, use the default message
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
@ -8,6 +11,9 @@ from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
@ -57,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
|
||||
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
|
||||
if self.joined and not self.workspace_id:
|
||||
raise ValueError("workspace_id must be non-empty when joined is True")
|
||||
return self
|
||||
@ -230,43 +236,60 @@ class EnterpriseService:
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_cached_license_status(cls):
|
||||
"""
|
||||
Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
|
||||
Only caches valid statuses (active/expiring) since invalid statuses
|
||||
should be re-checked every request — the admin may update the license
|
||||
at any time.
|
||||
Caches valid statuses (active/expiring) for 10 minutes. Invalid statuses
|
||||
are not cached so license updates are picked up on the next request.
|
||||
|
||||
Returns license status string or None if unavailable.
|
||||
Returns:
|
||||
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
|
||||
"""
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return None
|
||||
|
||||
# Try cache first — only valid statuses are cached
|
||||
try:
|
||||
cached_status = redis_client.get(LICENSE_STATUS_CACHE_KEY)
|
||||
if cached_status:
|
||||
if isinstance(cached_status, bytes):
|
||||
cached_status = cached_status.decode("utf-8")
|
||||
return cached_status
|
||||
except Exception:
|
||||
logger.debug("Failed to get license status from cache, calling enterprise API")
|
||||
cached = cls._read_cached_license_status()
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
return cls._fetch_and_cache_license_status()
|
||||
|
||||
@classmethod
|
||||
def _read_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Read license status from Redis cache, returning None on miss or failure."""
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
try:
|
||||
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
|
||||
if raw:
|
||||
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
||||
return LicenseStatus(value)
|
||||
except Exception:
|
||||
logger.debug("Failed to read license status from cache", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
|
||||
"""Fetch license status from enterprise API and cache the result.
|
||||
|
||||
Only caches valid statuses (active/expiring) so license updates
|
||||
for invalid statuses are picked up on the next request.
|
||||
"""
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
# Cache miss or failure — call enterprise API
|
||||
try:
|
||||
info = cls.get_info()
|
||||
license_info = info.get("License")
|
||||
if license_info:
|
||||
status = license_info.get("status", "inactive")
|
||||
# Only cache valid statuses so license updates are picked up immediately
|
||||
if status in ("active", "expiring"):
|
||||
try:
|
||||
redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status)
|
||||
except Exception:
|
||||
logger.debug("Failed to cache license status")
|
||||
return status
|
||||
except Exception:
|
||||
logger.exception("Failed to get enterprise license status")
|
||||
if not license_info:
|
||||
return None
|
||||
|
||||
return None
|
||||
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING):
|
||||
try:
|
||||
redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status)
|
||||
except Exception:
|
||||
logger.debug("Failed to cache license status", exc_info=True)
|
||||
return status
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch enterprise license status", exc_info=True)
|
||||
return None
|
||||
|
||||
@ -70,12 +70,11 @@ class PluginManagerService:
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json=body.model_dump(),
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s, ",
|
||||
"failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s",
|
||||
body.tenant_id,
|
||||
body.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@ -30,10 +30,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import ProviderCredential
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
from services.enterprise.plugin_manager_service import PluginManagerService, PreUninstallPluginRequest
|
||||
from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import FeatureService, PluginInstallationScope
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
||||
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(
|
||||
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
|
||||
@ -435,6 +437,46 @@ class ToolTransformService:
|
||||
:return: list of ToolParameter instances
|
||||
"""
|
||||
|
||||
def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str:
|
||||
"""
|
||||
Resolve a JSON schema property type while guarding against cyclic or deeply nested unions.
|
||||
"""
|
||||
if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH:
|
||||
return "string"
|
||||
prop_type = prop.get("type")
|
||||
if isinstance(prop_type, list):
|
||||
non_null_types = [type_name for type_name in prop_type if type_name != "null"]
|
||||
if non_null_types:
|
||||
return non_null_types[0]
|
||||
if prop_type:
|
||||
return "string"
|
||||
elif isinstance(prop_type, str):
|
||||
if prop_type == "null":
|
||||
return "string"
|
||||
return prop_type
|
||||
|
||||
for union_key in ("anyOf", "oneOf"):
|
||||
union_schemas = prop.get(union_key)
|
||||
if not isinstance(union_schemas, list):
|
||||
continue
|
||||
|
||||
for union_schema in union_schemas:
|
||||
if not isinstance(union_schema, dict):
|
||||
continue
|
||||
union_type = resolve_property_type(union_schema, depth + 1)
|
||||
if union_type != "null":
|
||||
return union_type
|
||||
|
||||
all_of_schemas = prop.get("allOf")
|
||||
if isinstance(all_of_schemas, list):
|
||||
for all_of_schema in all_of_schemas:
|
||||
if not isinstance(all_of_schema, dict):
|
||||
continue
|
||||
all_of_type = resolve_property_type(all_of_schema, depth + 1)
|
||||
if all_of_type != "null":
|
||||
return all_of_type
|
||||
return "string"
|
||||
|
||||
def create_parameter(
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||
) -> ToolParameter:
|
||||
@ -461,10 +503,7 @@ class ToolTransformService:
|
||||
parameters = []
|
||||
for name, prop in props.items():
|
||||
current_description = prop.get("description", "")
|
||||
prop_type = prop.get("type", "string")
|
||||
|
||||
if isinstance(prop_type, list):
|
||||
prop_type = prop_type[0]
|
||||
prop_type = resolve_property_type(prop)
|
||||
if prop_type in TYPE_MAPPING:
|
||||
prop_type = TYPE_MAPPING[prop_type]
|
||||
input_schema = prop if prop_type in COMPLEX_TYPES else None
|
||||
|
||||
@ -9,85 +9,87 @@ from unittest.mock import patch
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
|
||||
_FAKE_TIMEOUT = 30
|
||||
|
||||
|
||||
_SEND_REQUEST_PATH = (
|
||||
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
|
||||
)
|
||||
_DIFY_CONFIG_PATH = "services.enterprise.plugin_manager_service.dify_config"
|
||||
_LOGGER_PATH = "services.enterprise.plugin_manager_service.logger"
|
||||
|
||||
|
||||
class TestTryPreUninstallPlugin:
|
||||
def test_try_pre_uninstall_plugin_success(self):
|
||||
@patch(_DIFY_CONFIG_PATH)
|
||||
@patch(_SEND_REQUEST_PATH)
|
||||
def test_try_pre_uninstall_plugin_success(self, mock_send_request, mock_config):
|
||||
body = PreUninstallPluginRequest(
|
||||
tenant_id="tenant-123",
|
||||
plugin_unique_identifier="com.example.my_plugin",
|
||||
)
|
||||
mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT
|
||||
mock_send_request.return_value = {}
|
||||
|
||||
with patch(
|
||||
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
|
||||
) as mock_send_request:
|
||||
mock_send_request.return_value = {}
|
||||
PluginManagerService.try_pre_uninstall_plugin(body)
|
||||
|
||||
PluginManagerService.try_pre_uninstall_plugin(body)
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
|
||||
timeout=_FAKE_TIMEOUT,
|
||||
)
|
||||
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
def test_try_pre_uninstall_plugin_http_error_soft_fails(self):
|
||||
@patch(_DIFY_CONFIG_PATH)
|
||||
@patch(_LOGGER_PATH)
|
||||
@patch(_SEND_REQUEST_PATH)
|
||||
def test_try_pre_uninstall_plugin_http_error_soft_fails(
|
||||
self, mock_send_request, mock_logger, mock_config
|
||||
):
|
||||
body = PreUninstallPluginRequest(
|
||||
tenant_id="tenant-456",
|
||||
plugin_unique_identifier="com.example.other_plugin",
|
||||
)
|
||||
mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT
|
||||
mock_send_request.side_effect = HTTPStatusError(
|
||||
"502 Bad Gateway",
|
||||
request=None,
|
||||
response=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
|
||||
) as mock_send_request,
|
||||
patch("services.enterprise.plugin_manager_service.logger") as mock_logger,
|
||||
):
|
||||
mock_send_request.side_effect = HTTPStatusError(
|
||||
"502 Bad Gateway",
|
||||
request=None,
|
||||
response=None,
|
||||
)
|
||||
PluginManagerService.try_pre_uninstall_plugin(body)
|
||||
|
||||
PluginManagerService.try_pre_uninstall_plugin(body)
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
|
||||
timeout=_FAKE_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self):
|
||||
@patch(_DIFY_CONFIG_PATH)
|
||||
@patch(_LOGGER_PATH)
|
||||
@patch(_SEND_REQUEST_PATH)
|
||||
def test_try_pre_uninstall_plugin_generic_exception_soft_fails(
|
||||
self, mock_send_request, mock_logger, mock_config
|
||||
):
|
||||
body = PreUninstallPluginRequest(
|
||||
tenant_id="tenant-789",
|
||||
plugin_unique_identifier="com.example.failing_plugin",
|
||||
)
|
||||
mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT
|
||||
mock_send_request.side_effect = ConnectionError("network unreachable")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
|
||||
) as mock_send_request,
|
||||
patch("services.enterprise.plugin_manager_service.logger") as mock_logger,
|
||||
):
|
||||
mock_send_request.side_effect = ConnectionError("network unreachable")
|
||||
PluginManagerService.try_pre_uninstall_plugin(body)
|
||||
|
||||
PluginManagerService.try_pre_uninstall_plugin(body)
|
||||
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
|
||||
timeout=_FAKE_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@ -175,6 +175,137 @@ class TestMCPToolTransform:
|
||||
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
|
||||
# which should be tested separately
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self):
|
||||
"""Nullable object schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"anyOf": [{"type": "object"}, {"type": "null"}],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self):
|
||||
"""Nullable oneOf object schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"oneOf": [{"type": "object"}, {"type": "null"}],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_handles_null_type(self):
|
||||
"""Schemas with only a null type should fall back to string."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"null_prop_str": {"type": "null"},
|
||||
"null_prop_list": {"type": ["null"]},
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 2
|
||||
param_map = {parameter.name: parameter for parameter in result}
|
||||
assert "null_prop_str" in param_map
|
||||
assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING
|
||||
assert "null_prop_list" in param_map
|
||||
assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self):
|
||||
"""Property-level allOf with multiple object items should still resolve to object."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
},
|
||||
"required": ["enabled"],
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||
},
|
||||
"required": ["priority"],
|
||||
},
|
||||
],
|
||||
"description": "Config must match all schemas (allOf)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "config"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["config"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self):
|
||||
"""Composed property schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"allOf": [
|
||||
{"type": "object"},
|
||||
{"properties": {"top_k": {"type": "integer"}}},
|
||||
],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self):
|
||||
"""Self-referential composed schemas should stop resolving after the configured max depth."""
|
||||
recursive_property: dict[str, object] = {"description": "Recursive schema"}
|
||||
recursive_property["anyOf"] = [recursive_property]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"recursive_config": recursive_property,
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "recursive_config"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.STRING
|
||||
assert result[0].input_schema is None
|
||||
|
||||
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=True."""
|
||||
# Set tools data with null description
|
||||
|
||||
Loading…
Reference in New Issue
Block a user