Merge remote-tracking branch 'upstream/release/e-1.12.1' into deploy/enterprise

This commit is contained in:
yunlu.wen 2026-03-12 13:02:44 +08:00
commit 42226fdd3b
9 changed files with 327 additions and 112 deletions

View File

@ -11,9 +11,35 @@ from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check.
# Defined at module level to avoid per-request tuple construction.
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - features: billing/plan features (ProviderContextProvider)
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
# - workspaces/current: workspace + model providers (AppContextProvider)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/features",
"/console/api/account/profile",
"/console/api/workspaces/current",
"/console/api/version",
"/console/api/activate/check",
)
# ----------------------------
# Application Factory Function
@ -38,18 +64,12 @@ def create_flask_app_with_configs() -> DifyApp:
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api")
is_webapp_api = request.path.startswith("/api") and not is_console_api
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/") and not is_console_api
if is_console_api or is_webapp_api:
if is_console_api:
console_exempt_prefixes = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/version",
"/console/api/activate/check",
)
is_exempt = any(request.path.startswith(p) for p in console_exempt_prefixes)
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
@ -57,10 +77,13 @@ def create_flask_app_with_configs() -> DifyApp:
try:
# Check license status with caching (10 min TTL)
license_status = EnterpriseService.get_cached_license_status()
if license_status in ["inactive", "expired", "lost"]:
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
# Cookie clearing is handled by register_external_error_handlers
# in libs/external_api.py which detects the error code and calls
# build_force_logout_cookie_headers(). Frontend then checks
# code === 'unauthorized_and_force_logout' and calls location.reload().
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. "
"Please contact your administrator."
f"Enterprise license is {license_status}. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise

View File

@ -19,6 +19,10 @@ class EnterpriseFeatureConfig(BaseSettings):
default=False,
)
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
)
class EnterpriseTelemetryConfig(BaseSettings):
"""

View File

@ -101,10 +101,7 @@ class BaseRequest:
# {"message": "..."}
# {"detail": "..."}
error_message = (
error_data.get("message")
or error_data.get("error")
or error_data.get("detail")
or error_message
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
)
except Exception:
# If JSON parsing fails, use the default message

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -8,6 +11,9 @@ from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
@ -57,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
@ -230,43 +236,60 @@ class EnterpriseService:
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def get_cached_license_status(cls):
"""
Get enterprise license status with Redis caching to reduce HTTP calls.
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.
Only caches valid statuses (active/expiring) since invalid statuses
should be re-checked every request the admin may update the license
at any time.
Caches valid statuses (active/expiring) for 10 minutes. Invalid statuses
are not cached so license updates are picked up on the next request.
Returns license status string or None if unavailable.
Returns:
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
"""
if not dify_config.ENTERPRISE_ENABLED:
return None
# Try cache first — only valid statuses are cached
try:
cached_status = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if cached_status:
if isinstance(cached_status, bytes):
cached_status = cached_status.decode("utf-8")
return cached_status
except Exception:
logger.debug("Failed to get license status from cache, calling enterprise API")
cached = cls._read_cached_license_status()
if cached is not None:
return cached
return cls._fetch_and_cache_license_status()
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
"""Read license status from Redis cache, returning None on miss or failure."""
from services.feature_service import LicenseStatus
try:
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if raw:
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
return LicenseStatus(value)
except Exception:
logger.debug("Failed to read license status from cache", exc_info=True)
return None
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
"""Fetch license status from enterprise API and cache the result.
Only caches valid statuses (active/expiring) so license updates
for invalid statuses are picked up on the next request.
"""
from services.feature_service import LicenseStatus
# Cache miss or failure — call enterprise API
try:
info = cls.get_info()
license_info = info.get("License")
if license_info:
status = license_info.get("status", "inactive")
# Only cache valid statuses so license updates are picked up immediately
if status in ("active", "expiring"):
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status)
except Exception:
logger.debug("Failed to cache license status")
return status
except Exception:
logger.exception("Failed to get enterprise license status")
if not license_info:
return None
return None
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING):
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status)
except Exception:
logger.debug("Failed to cache license status", exc_info=True)
return status
except Exception:
logger.debug("Failed to fetch enterprise license status", exc_info=True)
return None

View File

@ -70,12 +70,11 @@ class PluginManagerService:
"POST",
"/pre-uninstall-plugin",
json=body.model_dump(),
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
except Exception:
logger.exception(
"failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s, ",
"failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s",
body.tenant_id,
body.plugin_unique_identifier,
)

View File

@ -30,10 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.enterprise.plugin_manager_service import (
PluginManagerService,
PreUninstallPluginRequest,
)
from services.enterprise.plugin_manager_service import PluginManagerService, PreUninstallPluginRequest
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope

View File

@ -33,6 +33,8 @@ logger = logging.getLogger(__name__)
class ToolTransformService:
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
@classmethod
def get_tool_provider_icon_url(
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
@ -435,6 +437,46 @@ class ToolTransformService:
:return: list of ToolParameter instances
"""
def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str:
"""
Resolve a JSON schema property type while guarding against cyclic or deeply nested unions.
"""
if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH:
return "string"
prop_type = prop.get("type")
if isinstance(prop_type, list):
non_null_types = [type_name for type_name in prop_type if type_name != "null"]
if non_null_types:
return non_null_types[0]
if prop_type:
return "string"
elif isinstance(prop_type, str):
if prop_type == "null":
return "string"
return prop_type
for union_key in ("anyOf", "oneOf"):
union_schemas = prop.get(union_key)
if not isinstance(union_schemas, list):
continue
for union_schema in union_schemas:
if not isinstance(union_schema, dict):
continue
union_type = resolve_property_type(union_schema, depth + 1)
if union_type != "null":
return union_type
all_of_schemas = prop.get("allOf")
if isinstance(all_of_schemas, list):
for all_of_schema in all_of_schemas:
if not isinstance(all_of_schema, dict):
continue
all_of_type = resolve_property_type(all_of_schema, depth + 1)
if all_of_type != "null":
return all_of_type
return "string"
def create_parameter(
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
) -> ToolParameter:
@ -461,10 +503,7 @@ class ToolTransformService:
parameters = []
for name, prop in props.items():
current_description = prop.get("description", "")
prop_type = prop.get("type", "string")
if isinstance(prop_type, list):
prop_type = prop_type[0]
prop_type = resolve_property_type(prop)
if prop_type in TYPE_MAPPING:
prop_type = TYPE_MAPPING[prop_type]
input_schema = prop if prop_type in COMPLEX_TYPES else None

View File

@ -9,85 +9,87 @@ from unittest.mock import patch
from httpx import HTTPStatusError
from configs import dify_config
from services.enterprise.plugin_manager_service import (
PluginManagerService,
PreUninstallPluginRequest,
)
_FAKE_TIMEOUT = 30
_SEND_REQUEST_PATH = (
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
)
_DIFY_CONFIG_PATH = "services.enterprise.plugin_manager_service.dify_config"
_LOGGER_PATH = "services.enterprise.plugin_manager_service.logger"
class TestTryPreUninstallPlugin:
def test_try_pre_uninstall_plugin_success(self):
@patch(_DIFY_CONFIG_PATH)
@patch(_SEND_REQUEST_PATH)
def test_try_pre_uninstall_plugin_success(self, mock_send_request, mock_config):
body = PreUninstallPluginRequest(
tenant_id="tenant-123",
plugin_unique_identifier="com.example.my_plugin",
)
mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT
mock_send_request.return_value = {}
with patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request:
mock_send_request.return_value = {}
PluginManagerService.try_pre_uninstall_plugin(body)
PluginManagerService.try_pre_uninstall_plugin(body)
mock_send_request.assert_called_once_with(
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
timeout=_FAKE_TIMEOUT,
)
mock_send_request.assert_called_once_with(
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
def test_try_pre_uninstall_plugin_http_error_soft_fails(self):
@patch(_DIFY_CONFIG_PATH)
@patch(_LOGGER_PATH)
@patch(_SEND_REQUEST_PATH)
def test_try_pre_uninstall_plugin_http_error_soft_fails(
self, mock_send_request, mock_logger, mock_config
):
body = PreUninstallPluginRequest(
tenant_id="tenant-456",
plugin_unique_identifier="com.example.other_plugin",
)
mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT
mock_send_request.side_effect = HTTPStatusError(
"502 Bad Gateway",
request=None,
response=None,
)
with (
patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request,
patch("services.enterprise.plugin_manager_service.logger") as mock_logger,
):
mock_send_request.side_effect = HTTPStatusError(
"502 Bad Gateway",
request=None,
response=None,
)
PluginManagerService.try_pre_uninstall_plugin(body)
PluginManagerService.try_pre_uninstall_plugin(body)
mock_send_request.assert_called_once_with(
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
timeout=_FAKE_TIMEOUT,
)
mock_logger.exception.assert_called_once()
mock_send_request.assert_called_once_with(
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self):
@patch(_DIFY_CONFIG_PATH)
@patch(_LOGGER_PATH)
@patch(_SEND_REQUEST_PATH)
def test_try_pre_uninstall_plugin_generic_exception_soft_fails(
self, mock_send_request, mock_logger, mock_config
):
body = PreUninstallPluginRequest(
tenant_id="tenant-789",
plugin_unique_identifier="com.example.failing_plugin",
)
mock_config.ENTERPRISE_REQUEST_TIMEOUT = _FAKE_TIMEOUT
mock_send_request.side_effect = ConnectionError("network unreachable")
with (
patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request,
patch("services.enterprise.plugin_manager_service.logger") as mock_logger,
):
mock_send_request.side_effect = ConnectionError("network unreachable")
PluginManagerService.try_pre_uninstall_plugin(body)
PluginManagerService.try_pre_uninstall_plugin(body)
mock_send_request.assert_called_once_with(
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
mock_send_request.assert_called_once_with(
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
timeout=_FAKE_TIMEOUT,
)
mock_logger.exception.assert_called_once()

View File

@ -7,7 +7,7 @@ import pytest
from core.mcp.types import Tool as MCPTool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -175,6 +175,137 @@ class TestMCPToolTransform:
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
# which should be tested separately
def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self):
"""Nullable object schemas should keep the object parameter type."""
schema = {
"type": "object",
"properties": {
"retrieval_model": {
"anyOf": [{"type": "object"}, {"type": "null"}],
"description": "检索模型配置",
}
},
}
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(result) == 1
assert result[0].name == "retrieval_model"
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
assert result[0].input_schema == schema["properties"]["retrieval_model"]
def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self):
"""Nullable oneOf object schemas should keep the object parameter type."""
schema = {
"type": "object",
"properties": {
"retrieval_model": {
"oneOf": [{"type": "object"}, {"type": "null"}],
"description": "检索模型配置",
}
},
}
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(result) == 1
assert result[0].name == "retrieval_model"
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
assert result[0].input_schema == schema["properties"]["retrieval_model"]
def test_convert_mcp_schema_to_parameter_handles_null_type(self):
"""Schemas with only a null type should fall back to string."""
schema = {
"type": "object",
"properties": {
"null_prop_str": {"type": "null"},
"null_prop_list": {"type": ["null"]},
},
}
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(result) == 2
param_map = {parameter.name: parameter for parameter in result}
assert "null_prop_str" in param_map
assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING
assert "null_prop_list" in param_map
assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self):
"""Property-level allOf with multiple object items should still resolve to object."""
schema = {
"type": "object",
"properties": {
"config": {
"allOf": [
{
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
},
"required": ["enabled"],
},
{
"type": "object",
"properties": {
"priority": {"type": "integer", "minimum": 1, "maximum": 10},
},
"required": ["priority"],
},
],
"description": "Config must match all schemas (allOf)",
}
},
}
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(result) == 1
assert result[0].name == "config"
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
assert result[0].input_schema == schema["properties"]["config"]
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self):
"""Composed property schemas should keep the object parameter type."""
schema = {
"type": "object",
"properties": {
"retrieval_model": {
"allOf": [
{"type": "object"},
{"properties": {"top_k": {"type": "integer"}}},
],
"description": "检索模型配置",
}
},
}
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(result) == 1
assert result[0].name == "retrieval_model"
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
assert result[0].input_schema == schema["properties"]["retrieval_model"]
def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self):
"""Self-referential composed schemas should stop resolving after the configured max depth."""
recursive_property: dict[str, object] = {"description": "Recursive schema"}
recursive_property["anyOf"] = [recursive_property]
schema = {
"type": "object",
"properties": {
"recursive_config": recursive_property,
},
}
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
assert len(result) == 1
assert result[0].name == "recursive_config"
assert result[0].type == ToolParameter.ToolParameterType.STRING
assert result[0].input_schema is None
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
"""Test mcp_provider_to_user_provider with for_list=True."""
# Set tools data with null description