diff --git a/api/app_factory.py b/api/app_factory.py index 15fae5525d..78cbbed765 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -11,9 +11,35 @@ from controllers.console.error import UnauthorizedAndForceLogout from core.logging.context import init_request_context from dify_app import DifyApp from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import LicenseStatus logger = logging.getLogger(__name__) +# Console bootstrap APIs exempt from license check. +# Defined at module level to avoid per-request tuple construction. +# - system-features: license status for expiry UI (GlobalPublicStoreProvider) +# - setup: install/setup status check (AppInitializer) +# - init: init password validation for fresh install (InitPasswordPopup) +# - login: auto-login after setup completion (InstallForm) +# - features: billing/plan features (ProviderContextProvider) +# - account/profile: login check + user profile (AppContextProvider, useIsLogin) +# - workspaces/current: workspace + model providers (AppContextProvider) +# - version: version check (AppContextProvider) +# - activate/check: invitation link validation (signin page) +# Without these exemptions, the signin page triggers location.reload() +# on unauthorized_and_force_logout, causing an infinite loop. +_CONSOLE_EXEMPT_PREFIXES = ( + "/console/api/system-features", + "/console/api/setup", + "/console/api/init", + "/console/api/login", + "/console/api/features", + "/console/api/account/profile", + "/console/api/workspaces/current", + "/console/api/version", + "/console/api/activate/check", +) + # ---------------------------- # Application Factory Function @@ -38,18 +64,12 @@ def create_flask_app_with_configs() -> DifyApp: # When license expires, block all API access except bootstrap endpoints needed # for the frontend to load the license expiration page without infinite reloads. if dify_config.ENTERPRISE_ENABLED: - is_console_api = request.path.startswith("/console/api") - is_webapp_api = request.path.startswith("/api") and not is_console_api + is_console_api = request.path.startswith("/console/api/") + is_webapp_api = request.path.startswith("/api/") and not is_console_api if is_console_api or is_webapp_api: if is_console_api: - console_exempt_prefixes = ( - "/console/api/system-features", - "/console/api/setup", - "/console/api/version", - "/console/api/activate/check", - ) - is_exempt = any(request.path.startswith(p) for p in console_exempt_prefixes) + is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES) else: # webapp API is_exempt = request.path.startswith("/api/system-features") @@ -57,10 +77,13 @@ def create_flask_app_with_configs() -> DifyApp: try: # Check license status with caching (10 min TTL) license_status = EnterpriseService.get_cached_license_status() - if license_status in ["inactive", "expired", "lost"]: + if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST): + # Cookie clearing is handled by register_external_error_handlers + # in libs/external_api.py which detects the error code and calls + # build_force_logout_cookie_headers(). Frontend then checks + # code === 'unauthorized_and_force_logout' and calls location.reload(). raise UnauthorizedAndForceLogout( - f"Enterprise license is {license_status}. " - "Please contact your administrator." + f"Enterprise license is {license_status}. Please contact your administrator." ) except UnauthorizedAndForceLogout: raise diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index 01450f081b..0d9d12c345 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -19,6 +19,10 @@ class EnterpriseFeatureConfig(BaseSettings): default=False, ) + ENTERPRISE_REQUEST_TIMEOUT: int = Field( + ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 + ) + class EnterpriseTelemetryConfig(BaseSettings): """ diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 86cca34cf2..cc29ecfdc4 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -101,10 +101,7 @@ class BaseRequest: # {"message": "..."} # {"detail": "..."} error_message = ( - error_data.get("message") - or error_data.get("error") - or error_data.get("detail") - or error_message + error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message ) except Exception: # If JSON parsing fails, use the default message diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index c2d89283a6..4e6638ebd9 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import uuid from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -8,6 +11,9 @@ from configs import dify_config from extensions.ext_redis import redis_client from services.enterprise.base import EnterpriseRequest +if TYPE_CHECKING: + from services.feature_service import LicenseStatus + logger = logging.getLogger(__name__) DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 @@ -57,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel): model_config = ConfigDict(extra="forbid", populate_by_name=True) @model_validator(mode="after") - def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult": + def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult: if self.joined and not self.workspace_id: raise ValueError("workspace_id must be non-empty when joined is True") return self @@ -230,43 +236,60 @@ class EnterpriseService: EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) @classmethod - def get_cached_license_status(cls): - """ - Get enterprise license status with Redis caching to reduce HTTP calls. + def get_cached_license_status(cls) -> LicenseStatus | None: + """Get enterprise license status with Redis caching to reduce HTTP calls. - Only caches valid statuses (active/expiring) since invalid statuses - should be re-checked every request — the admin may update the license - at any time. + Caches valid statuses (active/expiring) for 10 minutes. Invalid statuses + are not cached so license updates are picked up on the next request. - Returns license status string or None if unavailable. + Returns: + LicenseStatus enum value, or None if enterprise is disabled / unreachable. """ if not dify_config.ENTERPRISE_ENABLED: return None - # Try cache first — only valid statuses are cached - try: - cached_status = redis_client.get(LICENSE_STATUS_CACHE_KEY) - if cached_status: - if isinstance(cached_status, bytes): - cached_status = cached_status.decode("utf-8") - return cached_status - except Exception: - logger.debug("Failed to get license status from cache, calling enterprise API") + cached = cls._read_cached_license_status() + if cached is not None: + return cached + + return cls._fetch_and_cache_license_status() + + @classmethod + def _read_cached_license_status(cls) -> LicenseStatus | None: + """Read license status from Redis cache, returning None on miss or failure.""" + from services.feature_service import LicenseStatus + + try: + raw = redis_client.get(LICENSE_STATUS_CACHE_KEY) + if raw: + value = raw.decode("utf-8") if isinstance(raw, bytes) else raw + return LicenseStatus(value) + except Exception: + logger.debug("Failed to read license status from cache", exc_info=True) + return None + + @classmethod + def _fetch_and_cache_license_status(cls) -> LicenseStatus | None: + """Fetch license status from enterprise API and cache the result. + + Only caches valid statuses (active/expiring) so license updates + for invalid statuses are picked up on the next request. + """ + from services.feature_service import LicenseStatus - # Cache miss or failure — call enterprise API try: info = cls.get_info() license_info = info.get("License") - if license_info: - status = license_info.get("status", "inactive") - # Only cache valid statuses so license updates are picked up immediately - if status in ("active", "expiring"): - try: - redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status) - except Exception: - logger.debug("Failed to cache license status") - return status - except Exception: - logger.exception("Failed to get enterprise license status") + if not license_info: + return None - return None \ No newline at end of file + status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING): + try: + redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status) + except Exception: + logger.debug("Failed to cache license status", exc_info=True) + return status + except Exception: + logger.debug("Failed to fetch enterprise license status", exc_info=True) + return None diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index 2769edd765..d4be36305e 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -70,12 +70,11 @@ class PluginManagerService: "POST", "/pre-uninstall-plugin", json=body.model_dump(), - raise_for_status=True, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, ) except Exception: logger.exception( - "failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s, ", + "failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s", body.tenant_id, body.plugin_unique_identifier, ) diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 91dd54bf38..d063c57b32 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -30,10 +30,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider import ProviderCredential from models.provider_ids import GenericProviderID -from services.enterprise.plugin_manager_service import ( - PluginManagerService, - PreUninstallPluginRequest, -) +from services.enterprise.plugin_manager_service import PluginManagerService, PreUninstallPluginRequest from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index e323b3cda9..b6e5367023 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -33,6 +33,8 @@ logger = logging.getLogger(__name__) class ToolTransformService: + _MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10 + @classmethod def get_tool_provider_icon_url( cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str] @@ -435,6 +437,46 @@ class ToolTransformService: :return: list of ToolParameter instances """ + def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str: + """ + Resolve a JSON schema property type while guarding against cyclic or deeply nested unions. + """ + if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH: + return "string" + prop_type = prop.get("type") + if isinstance(prop_type, list): + non_null_types = [type_name for type_name in prop_type if type_name != "null"] + if non_null_types: + return non_null_types[0] + if prop_type: + return "string" + elif isinstance(prop_type, str): + if prop_type == "null": + return "string" + return prop_type + + for union_key in ("anyOf", "oneOf"): + union_schemas = prop.get(union_key) + if not isinstance(union_schemas, list): + continue + + for union_schema in union_schemas: + if not isinstance(union_schema, dict): + continue + union_type = resolve_property_type(union_schema, depth + 1) + if union_type != "null": + return union_type + + all_of_schemas = prop.get("allOf") + if isinstance(all_of_schemas, list): + for all_of_schema in all_of_schemas: + if not isinstance(all_of_schema, dict): + continue + all_of_type = resolve_property_type(all_of_schema, depth + 1) + if all_of_type != "null": + return all_of_type + return "string" + def create_parameter( name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None ) -> ToolParameter: @@ -461,10 +503,7 @@ class ToolTransformService: parameters = [] for name, prop in props.items(): current_description = prop.get("description", "") - prop_type = prop.get("type", "string") - - if isinstance(prop_type, list): - prop_type = prop_type[0] + prop_type = resolve_property_type(prop) if prop_type in TYPE_MAPPING: prop_type = TYPE_MAPPING[prop_type] input_schema = prop if prop_type in COMPLEX_TYPES else None diff --git a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py index d5f34d00b9..bd81f0ff89 100644 --- a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py +++ b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py @@ -9,85 +9,87 @@ from unittest.mock import patch from httpx import HTTPStatusError -from configs import dify_config from services.enterprise.plugin_manager_service import ( PluginManagerService, PreUninstallPluginRequest, ) +_FAKE_TIMEOUT = 30 + + +_SEND_REQUEST_PATH = ( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" +) +_DIFY_CONFIG_PATH = "services.enterprise.plugin_manager_service.dify_config" +_LOGGER_PATH = "services.enterprise.plugin_manager_service.logger" + class TestTryPreUninstallPlugin: - def test_try_pre_uninstall_plugin_success(self): + @patch(_DIFY_CONFIG_PATH) + @patch(_SEND_REQUEST_PATH) + def test_try_pre_uninstall_plugin_success(self, mock_send_request, mock_config): body = PreUninstallPluginRequest( tenant_id="tenant-123", plugin_unique_identifier="com.example.my_plugin", ) + mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT + mock_send_request.return_value = {} - with patch( - "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" - ) as mock_send_request: - mock_send_request.return_value = {} + PluginManagerService.try_pre_uninstall_plugin(body) - PluginManagerService.try_pre_uninstall_plugin(body) + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"}, + timeout=_FAKE_TIMEOUT, + ) - mock_send_request.assert_called_once_with( - "POST", - "/pre-uninstall-plugin", - json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"}, - raise_for_status=True, - timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, - ) - - def test_try_pre_uninstall_plugin_http_error_soft_fails(self): + @patch(_DIFY_CONFIG_PATH) + @patch(_LOGGER_PATH) + @patch(_SEND_REQUEST_PATH) + def test_try_pre_uninstall_plugin_http_error_soft_fails( + self, mock_send_request, mock_logger, mock_config + ): body = PreUninstallPluginRequest( tenant_id="tenant-456", plugin_unique_identifier="com.example.other_plugin", ) + mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT + mock_send_request.side_effect = HTTPStatusError( + "502 Bad Gateway", + request=None, + response=None, + ) - with ( - patch( - "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" - ) as mock_send_request, - patch("services.enterprise.plugin_manager_service.logger") as mock_logger, - ): - mock_send_request.side_effect = HTTPStatusError( - "502 Bad Gateway", - request=None, - response=None, - ) + PluginManagerService.try_pre_uninstall_plugin(body) - PluginManagerService.try_pre_uninstall_plugin(body) + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"}, + timeout=_FAKE_TIMEOUT, + ) + mock_logger.exception.assert_called_once() - mock_send_request.assert_called_once_with( - "POST", - "/pre-uninstall-plugin", - json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"}, - raise_for_status=True, - timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, - ) - mock_logger.exception.assert_called_once() - - def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self): + @patch(_DIFY_CONFIG_PATH) + @patch(_LOGGER_PATH) + @patch(_SEND_REQUEST_PATH) + def test_try_pre_uninstall_plugin_generic_exception_soft_fails( + self, mock_send_request, mock_logger, mock_config + ): body = PreUninstallPluginRequest( tenant_id="tenant-789", plugin_unique_identifier="com.example.failing_plugin", ) + mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT + mock_send_request.side_effect = ConnectionError("network unreachable") - with ( - patch( - "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" - ) as mock_send_request, - patch("services.enterprise.plugin_manager_service.logger") as mock_logger, - ): - mock_send_request.side_effect = ConnectionError("network unreachable") + PluginManagerService.try_pre_uninstall_plugin(body) - PluginManagerService.try_pre_uninstall_plugin(body) - - mock_send_request.assert_called_once_with( - "POST", - "/pre-uninstall-plugin", - json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"}, - raise_for_status=True, - timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, - ) - mock_logger.exception.assert_called_once() + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"}, + timeout=_FAKE_TIMEOUT, + ) + mock_logger.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py index 7511fd6f0c..9537d207f0 100644 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py @@ -7,7 +7,7 @@ import pytest from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -175,6 +175,137 @@ class TestMCPToolTransform: # The actual parameter conversion is handled by convert_mcp_schema_to_parameter # which should be tested separately + def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self): + """Nullable object schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "anyOf": [{"type": "object"}, {"type": "null"}], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self): + """Nullable oneOf object schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "oneOf": [{"type": "object"}, {"type": "null"}], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_handles_null_type(self): + """Schemas with only a null type should fall back to string.""" + schema = { + "type": "object", + "properties": { + "null_prop_str": {"type": "null"}, + "null_prop_list": {"type": ["null"]}, + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 2 + param_map = {parameter.name: parameter for parameter in result} + assert "null_prop_str" in param_map + assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING + assert "null_prop_list" in param_map + assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING + + def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self): + """Property-level allOf with multiple object items should still resolve to object.""" + schema = { + "type": "object", + "properties": { + "config": { + "allOf": [ + { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + }, + "required": ["enabled"], + }, + { + "type": "object", + "properties": { + "priority": {"type": "integer", "minimum": 1, "maximum": 10}, + }, + "required": ["priority"], + }, + ], + "description": "Config must match all schemas (allOf)", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "config" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["config"] + + def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self): + """Composed property schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "allOf": [ + {"type": "object"}, + {"properties": {"top_k": {"type": "integer"}}}, + ], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self): + """Self-referential composed schemas should stop resolving after the configured max depth.""" + recursive_property: dict[str, object] = {"description": "Recursive schema"} + recursive_property["anyOf"] = [recursive_property] + schema = { + "type": "object", + "properties": { + "recursive_config": recursive_property, + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "recursive_config" + assert result[0].type == ToolParameter.ToolParameterType.STRING + assert result[0].input_schema is None + def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full): """Test mcp_provider_to_user_provider with for_list=True.""" # Set tools data with null description