dify/api/app_factory.py
GareArc 0ed39d81e9
feat: add global license check middleware to block API access on expiry
Add before_request middleware that validates enterprise license status
for all /console/api endpoints when ENTERPRISE_ENABLED is true.

Behavior:
- Checks license status before each console API request
- Returns 403 with clear error message when license is expired/inactive/lost
- Exempts auth endpoints (login, oauth, forgot-password, etc.)
- Exempts /console/api/features so frontend can fetch license status
- Fails open if the license check itself errors (logs the exception and lets the request through) to avoid service disruption

This ensures all business APIs are blocked when the license expires,
addressing the issue where APIs remained callable after expiry.
2026-03-04 20:10:42 -08:00

206 lines
6.8 KiB
Python

import logging
import time
from flask import jsonify, request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.feature_service import FeatureService, LicenseStatus
# Module-level logger for the app factory, its request hooks and timing logs.
logger = logging.getLogger(__name__)

# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> DifyApp:
"""
create a raw flask app
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
# add before request hook
@dify_app.before_request
def before_request():
# Initialize logging context for this request
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for console API endpoints
if dify_config.ENTERPRISE_ENABLED and request.path.startswith("/console/api"):
# Skip license check for auth-related endpoints and system endpoints
exempt_paths = [
"/console/api/login",
"/console/api/logout",
"/console/api/oauth",
"/console/api/setup",
"/console/api/init",
"/console/api/forgot-password",
"/console/api/email-code-login",
"/console/api/activation",
"/console/api/data-source-oauth",
"/console/api/features", # Allow fetching features to show license status
]
# Check if current path is exempt
is_exempt = any(request.path.startswith(path) for path in exempt_paths)
if not is_exempt:
try:
# Check license status
system_features = FeatureService.get_system_features(is_authenticated=True)
if system_features.license.status in [
LicenseStatus.INACTIVE,
LicenseStatus.EXPIRED,
LicenseStatus.LOST,
]:
return jsonify({
"code": "license_expired",
"message": (
f"Enterprise license is {system_features.license.status.value}. "
"Please contact your administrator."
),
"status": system_features.license.status.value,
}), 403
except Exception:
# If license check fails, log but don't block the request
# This prevents service disruption if enterprise API is temporarily unavailable
logger.exception("Failed to check enterprise license status")
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
def add_trace_headers(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if not ctx or not ctx.is_valid:
return response
# Inject trace headers from OTEL context
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_headers
return dify_app
def create_app() -> DifyApp:
    """
    Build the fully initialised Dify Flask application.

    Creates the configured app, then initialises every enabled extension on
    it. When DEBUG is on, the total build time is logged in milliseconds.
    """
    began = time.perf_counter()
    application = create_flask_app_with_configs()
    initialize_extensions(application)
    elapsed_ms = round((time.perf_counter() - began) * 1000, 2)
    if dify_config.DEBUG:
        logger.info("Finished create_app (%s ms)", elapsed_ms)
    return application
def initialize_extensions(app: DifyApp):
    """
    Call ``init_app(app)`` on each extension module, in the listed order.

    A module that defines ``is_enabled()`` is skipped when that call returns a
    falsy value. When DEBUG is on, every skip is logged, and every load is
    logged together with its duration in milliseconds.
    """
    # Initialize Flask context capture for workflow execution
    from context.flask_app_context import init_flask_context
    from extensions import (
        ext_app_metrics,
        ext_blueprints,
        ext_celery,
        ext_code_based_extension,
        ext_commands,
        ext_compress,
        ext_database,
        ext_enterprise_telemetry,
        ext_fastopenapi,
        ext_forward_refs,
        ext_hosting_provider,
        ext_import_modules,
        ext_logging,
        ext_login,
        ext_logstore,
        ext_mail,
        ext_migrate,
        ext_orjson,
        ext_otel,
        ext_proxy_fix,
        ext_redis,
        ext_request_logging,
        ext_sentry,
        ext_session_factory,
        ext_set_secretkey,
        ext_storage,
        ext_timezone,
        ext_warnings,
    )

    init_flask_context()

    # The sequence is significant: some extensions must come after others
    # (e.g. logstore after storage and before celery), so keep this order.
    ordered_extensions = (
        ext_timezone,
        ext_logging,
        ext_warnings,
        ext_import_modules,
        ext_orjson,
        ext_forward_refs,
        ext_set_secretkey,
        ext_compress,
        ext_code_based_extension,
        ext_database,
        ext_app_metrics,
        ext_migrate,
        ext_redis,
        ext_storage,
        ext_logstore,  # Initialize logstore after storage, before celery
        ext_celery,
        ext_login,
        ext_mail,
        ext_hosting_provider,
        ext_sentry,
        ext_proxy_fix,
        ext_blueprints,
        ext_commands,
        ext_fastopenapi,
        ext_otel,
        ext_enterprise_telemetry,
        ext_request_logging,
        ext_session_factory,
    )

    for module in ordered_extensions:
        label = module.__name__.rpartition(".")[2]
        # Modules without an is_enabled() hook are always loaded.
        if hasattr(module, "is_enabled") and not module.is_enabled():
            if dify_config.DEBUG:
                logger.info("Skipped %s", label)
            continue

        t0 = time.perf_counter()
        module.init_app(app)
        took_ms = round((time.perf_counter() - t0) * 1000, 2)
        if dify_config.DEBUG:
            logger.info("Loaded %s (%s ms)", label, took_ms)
def create_migrations_app() -> DifyApp:
    """
    Create a minimal app for running database migrations.

    Unlike ``create_app``, only the database and migrate extensions are
    initialised on top of the configured Flask app (database first).
    """
    migrations_app = create_flask_app_with_configs()
    from extensions import ext_database, ext_migrate

    # Initialize only required extensions
    for required_ext in (ext_database, ext_migrate):
        required_ext.init_app(migrations_app)
    return migrations_app