From 41af72449d3f26d97937a0ad4bd46810f4380908 Mon Sep 17 00:00:00 2001 From: GareArc Date: Sun, 8 Mar 2026 17:00:12 -0700 Subject: [PATCH] fix: address PR review feedback on enterprise license enforcement - Cache invalid license statuses with 30s TTL to prevent DoS amplification - Return LicenseStatus enum (not raw str) from get_cached_license_status - Flatten nested try/except into _read_cached_license_status / _fetch_and_cache_license_status helpers - Escalate log levels from debug to warning with exc_info for cache failures - Switch before_request license check from fail-open to fail-closed - Remove dead raise_for_status parameter from BaseRequest.send_request - Gate license expired_at behind is_authenticated; only expose status to unauthenticated callers (CVE-2025-63387) - Remove redundant 'not is_console_api' guard in before_request - Add 8 unit tests for get_cached_license_status --- api/app_factory.py | 10 +- api/services/enterprise/base.py | 8 +- api/services/enterprise/enterprise_service.py | 85 +++++++---- api/services/feature_service.py | 18 ++- .../enterprise/test_enterprise_service.py | 142 +++++++++++++++++- 5 files changed, 210 insertions(+), 53 deletions(-) diff --git a/api/app_factory.py b/api/app_factory.py index 3029353f50..65cd6dfd4e 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -40,7 +40,7 @@ def create_flask_app_with_configs() -> DifyApp: # for the frontend to load the license expiration page without infinite reloads. if dify_config.ENTERPRISE_ENABLED: is_console_api = request.path.startswith("/console/api/") - is_webapp_api = request.path.startswith("/api/") and not is_console_api + is_webapp_api = request.path.startswith("/api/") if is_console_api or is_webapp_api: if is_console_api: @@ -56,7 +56,7 @@ def create_flask_app_with_configs() -> DifyApp: if not is_exempt: try: - # Check license status with caching (10 min TTL) + # Check license status (cached — see EnterpriseService for TTL details) license_status = EnterpriseService.get_cached_license_status() if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST): raise UnauthorizedAndForceLogout( @@ -65,7 +65,13 @@ def create_flask_app_with_configs() -> DifyApp: except UnauthorizedAndForceLogout: raise except Exception: + # Fail-closed: if we cannot verify the license (Redis down + + # enterprise API unreachable), block the request. An unreachable + # sidecar is itself an abnormal state that should surface. logger.exception("Failed to check enterprise license status") + raise UnauthorizedAndForceLogout( + "Unable to verify enterprise license. Please contact your administrator." + ) # add after request hook for injecting trace headers from OpenTelemetry span context # Only adds headers when OTEL is enabled and has valid context diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index cc29ecfdc4..7fe20b050b 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -48,7 +48,6 @@ class BaseRequest: params: Mapping[str, Any] | None = None, *, timeout: float | httpx.Timeout | None = None, - raise_for_status: bool = False, ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" @@ -72,14 +71,9 @@ class BaseRequest: response = client.request(method, url, **request_kwargs) - # Always validate HTTP status and raise domain-specific errors + # Validate HTTP status and raise domain-specific errors if not response.is_success: cls._handle_error_response(response) - - # Legacy support: still respect raise_for_status parameter - if raise_for_status: - response.raise_for_status() - return response.json() @classmethod diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 28c631b0e9..2b3ce056c9 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import uuid from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -8,12 +11,16 @@ from configs import dify_config from extensions.ext_redis import redis_client from services.enterprise.base import EnterpriseRequest +if TYPE_CHECKING: + from services.feature_service import LicenseStatus + logger = logging.getLogger(__name__) DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 # License status cache configuration LICENSE_STATUS_CACHE_KEY = "enterprise:license:status" -LICENSE_STATUS_CACHE_TTL = 600 # 10 minutes +VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable +INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly class WebAppSettings(BaseModel): @@ -56,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel): model_config = ConfigDict(extra="forbid", populate_by_name=True) @model_validator(mode="after") - def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult": + def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult: if self.joined and not self.workspace_id: raise ValueError("workspace_id must be non-empty when joined is True") return self @@ -119,7 +126,6 @@ class EnterpriseService: "/default-workspace/members", json={"account_id": account_id}, timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, - raise_for_status=True, ) if not isinstance(data, dict): raise ValueError("Invalid response format from enterprise default workspace API") @@ -229,45 +235,62 @@ class EnterpriseService: EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) @classmethod - def get_cached_license_status(cls): - """ - Get enterprise license status with Redis caching to reduce HTTP calls. + def get_cached_license_status(cls) -> LicenseStatus | None: + """Get enterprise license status with Redis caching to reduce HTTP calls. - Only caches valid statuses (active/expiring) since invalid statuses - should be re-checked every request — the admin may update the license - at any time. + Caches valid statuses (active/expiring) for 10 minutes and invalid statuses + (inactive/expired/lost) for 1 minute. The shorter TTL for invalid statuses + balances prompt license-fix detection against DoS mitigation — without + caching, every request on an expired license would hit the enterprise API. - Returns license status string or None if unavailable. + Returns: + LicenseStatus enum value, or None if enterprise is disabled / unreachable. """ if not dify_config.ENTERPRISE_ENABLED: return None - # Try cache first — only valid statuses are cached - try: - cached_status = redis_client.get(LICENSE_STATUS_CACHE_KEY) - if cached_status: - if isinstance(cached_status, bytes): - cached_status = cached_status.decode("utf-8") - return cached_status - except Exception: - logger.debug("Failed to get license status from cache, calling enterprise API") + cached = cls._read_cached_license_status() + if cached is not None: + return cached + + return cls._fetch_and_cache_license_status() + + @classmethod + def _read_cached_license_status(cls) -> LicenseStatus | None: + """Read license status from Redis cache, returning None on miss or failure.""" + from services.feature_service import LicenseStatus + + try: + raw = redis_client.get(LICENSE_STATUS_CACHE_KEY) + if raw: + value = raw.decode("utf-8") if isinstance(raw, bytes) else raw + return LicenseStatus(value) + except Exception: + logger.warning("Failed to read license status from cache", exc_info=True) + return None + + @classmethod + def _fetch_and_cache_license_status(cls) -> LicenseStatus | None: + """Fetch license status from enterprise API and cache the result.""" + from services.feature_service import LicenseStatus - # Cache miss or failure — call enterprise API try: info = cls.get_info() license_info = info.get("License") - if license_info: - from services.feature_service import LicenseStatus + if not license_info: + return None - status = license_info.get("status", LicenseStatus.INACTIVE) - # Only cache valid statuses so license updates are picked up immediately - if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING): - try: - redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status) - except Exception: - logger.debug("Failed to cache license status") - return status + status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + ttl = ( + VALID_LICENSE_CACHE_TTL + if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING) + else INVALID_LICENSE_CACHE_TTL + ) + try: + redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status) + except Exception: + logger.warning("Failed to cache license status", exc_info=True) + return status except Exception: logger.exception("Failed to get enterprise license status") - return None diff --git a/api/services/feature_service.py b/api/services/feature_service.py index a37d0a2167..a53ac9d980 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -379,17 +379,19 @@ class FeatureService: ) features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - # License status and expiry are always exposed so the login page can - # show the expiry UI after a force-logout (the user is unauthenticated - # at that point). Workspace usage details remain auth-gated. + # SECURITY NOTE: Only license *status* is exposed to unauthenticated callers + # so the login page can detect an expired/inactive license after force-logout. + # All other license details (expiry date, workspace usage) remain auth-gated. + # See CVE-2025-63387 for prior information-leakage context. if license_info := enterprise_info.get("License"): features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - features.license.expired_at = license_info.get("expiredAt", "") - if is_authenticated and (workspaces_info := license_info.get("workspaces")): - features.license.workspaces.enabled = workspaces_info.get("enabled", False) - features.license.workspaces.limit = workspaces_info.get("limit", 0) - features.license.workspaces.size = workspaces_info.get("used", 0) + if is_authenticated: + features.license.expired_at = license_info.get("expiredAt", "") + if workspaces_info := license_info.get("workspaces"): + features.license.workspaces.enabled = workspaces_info.get("enabled", False) + features.license.workspaces.limit = workspaces_info.get("limit", 0) + features.license.workspaces.size = workspaces_info.get("used", 0) if "PluginInstallationPermission" in enterprise_info: plugin_installation_info = enterprise_info["PluginInstallationPermission"] diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py index 03c4f793cf..59c07bfb37 100644 --- a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -1,9 +1,8 @@ """Unit tests for enterprise service integrations. -This module covers the enterprise-only default workspace auto-join behavior: -- Enterprise mode disabled: no external calls -- Successful join / skipped join: no errors -- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise +Covers: +- Default workspace auto-join behavior +- License status caching (get_cached_license_status) """ from unittest.mock import patch @@ -11,6 +10,9 @@ from unittest.mock import patch import pytest from services.enterprise.enterprise_service import ( + INVALID_LICENSE_CACHE_TTL, + LICENSE_STATUS_CACHE_KEY, + VALID_LICENSE_CACHE_TTL, DefaultWorkspaceJoinResult, EnterpriseService, try_join_default_workspace, @@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace: "/default-workspace/members", json={"account_id": account_id}, timeout=1.0, - raise_for_status=True, ) def test_join_default_workspace_invalid_response_format_raises(self): @@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace: # Should not raise even though UUID parsing fails inside join_default_workspace try_join_default_workspace("not-a-uuid") + + +# --------------------------------------------------------------------------- +# get_cached_license_status +# --------------------------------------------------------------------------- + +_EE_SVC = "services.enterprise.enterprise_service" + + +class TestGetCachedLicenseStatus: + """Tests for EnterpriseService.get_cached_license_status.""" + + def test_returns_none_when_enterprise_disabled(self): + with patch(f"{_EE_SVC}.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + assert EnterpriseService.get_cached_license_status() is None + + def test_cache_hit_returns_license_status_enum(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = b"active" + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.ACTIVE + assert isinstance(result, LicenseStatus) + mock_get_info.assert_not_called() + + def test_cache_miss_fetches_api_and_caches_valid_status(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.return_value = {"License": {"status": "active"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.ACTIVE + mock_redis.setex.assert_called_once_with( + LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE + ) + + def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.return_value = {"License": {"status": "expired"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.EXPIRED + mock_redis.setex.assert_called_once_with( + LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED + ) + + def test_redis_read_failure_falls_through_to_api(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.side_effect = ConnectionError("redis down") + mock_get_info.return_value = {"License": {"status": "active"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.ACTIVE + mock_get_info.assert_called_once() + + def test_redis_write_failure_still_returns_status(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_redis.setex.side_effect = ConnectionError("redis down") + mock_get_info.return_value = {"License": {"status": "expiring"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.EXPIRING + + def test_api_failure_returns_none(self): + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.side_effect = Exception("network failure") + + assert EnterpriseService.get_cached_license_status() is None + + def test_api_returns_no_license_info(self): + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.return_value = {} # no "License" key + + assert EnterpriseService.get_cached_license_status() is None + mock_redis.setex.assert_not_called()