dify/api/app_factory.py
GareArc 0e9dc86f3b
fix: use UnauthorizedAndForceLogout to trigger frontend logout on license expiry
Change license check to raise UnauthorizedAndForceLogout exception instead
of returning generic JSON response. This ensures proper frontend handling:

Frontend behavior (service/base.ts line 588):
- Checks if code === 'unauthorized_and_force_logout'
- Executes globalThis.location.reload()
- Forces user logout and redirect to login page
- Login page displays license expiration UI (already exists)

Response format:
- HTTP 401 (not 403)
- code: "unauthorized_and_force_logout"
- Triggers frontend reload which clears auth state

This completes the license enforcement flow:
1. Backend blocks all business APIs when license expires
2. Backend returns proper error code to trigger logout
3. Frontend reloads and redirects to login
4. Login page shows license expiration message
2026-03-04 20:30:53 -08:00

208 lines
7.0 KiB
Python

import logging
import time
from flask import request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> DifyApp:
    """
    Build a bare ``DifyApp`` configured from ``dify_config``.

    The Flask config is filled from ``dify_config.model_dump()`` and two
    request hooks are attached before the app is returned:

    * ``before_request`` initializes the request logging context, calls
      ``RecyclableContextVar.increment_thread_recycles()`` and, when
      ``ENTERPRISE_ENABLED`` is set, raises ``UnauthorizedAndForceLogout``
      for non-exempt ``/console/api`` paths while the license status is
      inactive, expired or lost.
    * ``after_request`` copies the current OpenTelemetry trace/span ids into
      the ``X-Trace-Id`` / ``X-Span-Id`` response headers when they are valid
      and not already present.

    No extensions are initialized here.
    """
    flask_app = DifyApp(__name__)
    flask_app.config.from_mapping(dify_config.model_dump())
    flask_app.config["RESTX_INCLUDE_ALL_MODELS"] = True

    # Console API prefixes that stay reachable regardless of license state:
    # auth-related and system endpoints.
    license_exempt_prefixes = (
        "/console/api/login",
        "/console/api/logout",
        "/console/api/oauth",
        "/console/api/setup",
        "/console/api/init",
        "/console/api/forgot-password",
        "/console/api/email-code-login",
        "/console/api/activation",
        "/console/api/data-source-oauth",
        "/console/api/features",  # lets the client fetch features to show license status
    )

    @flask_app.before_request
    def before_request():
        # Per-request bookkeeping: logging context first, then context-var recycling.
        init_request_context()
        RecyclableContextVar.increment_thread_recycles()

        # The license gate applies only to enterprise deployments ...
        if not dify_config.ENTERPRISE_ENABLED:
            return
        # ... and only to console API routes ...
        if not request.path.startswith("/console/api"):
            return
        # ... that are not on the exemption list (str.startswith accepts a tuple).
        if request.path.startswith(license_exempt_prefixes):
            return

        try:
            features = FeatureService.get_system_features(is_authenticated=True)
            license_status = features.license.status
            if license_status in (
                LicenseStatus.INACTIVE,
                LicenseStatus.EXPIRED,
                LicenseStatus.LOST,
            ):
                # The frontend looks for code === 'unauthorized_and_force_logout'
                # and reacts with location.reload(), which logs the user out.
                raise UnauthorizedAndForceLogout(
                    f"Enterprise license is {license_status.value}. "
                    "Please contact your administrator."
                )
        except UnauthorizedAndForceLogout:
            # Propagate so Flask's error handling produces the JSON response.
            raise
        except Exception:
            # Fail open: a broken license lookup is logged but must not block
            # the request (e.g. enterprise API temporarily unavailable).
            logger.exception("Failed to check enterprise license status")

    # Inject trace headers from the OpenTelemetry span context; a no-op unless
    # a valid span context is available.
    @flask_app.after_request
    def add_trace_headers(response):
        try:
            current_span = get_current_span()
            span_ctx = current_span.get_span_context() if current_span else None
            if span_ctx and span_ctx.is_valid:
                # Existing header values are never overwritten.
                if span_ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
                    response.headers["X-Trace-Id"] = f"{span_ctx.trace_id:032x}"
                if span_ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
                    response.headers["X-Span-Id"] = f"{span_ctx.span_id:016x}"
        except Exception:
            # Tracing header injection must never break the response.
            logger.warning("Failed to add trace headers to response", exc_info=True)
        return response

    # Reference the decorated hooks so pyright does not report reportUnusedFunction.
    _ = before_request
    _ = add_trace_headers

    return flask_app
def create_app() -> DifyApp:
    """
    Create the configured app and initialize every extension on it.

    When ``dify_config.DEBUG`` is set, the total construction time is logged
    in milliseconds.

    Returns:
        The ``DifyApp`` instance with extensions initialized.
    """
    began = time.perf_counter()
    application = create_flask_app_with_configs()
    initialize_extensions(application)
    elapsed_ms = round((time.perf_counter() - began) * 1000, 2)

    if dify_config.DEBUG:
        logger.info("Finished create_app (%s ms)", elapsed_ms)
    return application
def initialize_extensions(app: DifyApp):
    """
    Initialize the application's extensions on ``app`` in a fixed order.

    ``init_flask_context()`` is called first (Flask context capture for
    workflow execution). Each extension module is then visited in order: if
    it exposes ``is_enabled()`` and that returns a falsy value the module is
    skipped, otherwise ``init_app(app)`` is called. When
    ``dify_config.DEBUG`` is set, skips and per-extension load times are
    logged.

    Args:
        app: The application instance handed to each extension's ``init_app``.
    """
    # Imported lazily, inside the function, exactly as the extensions are.
    from context.flask_app_context import init_flask_context
    from extensions import (
        ext_app_metrics,
        ext_blueprints,
        ext_celery,
        ext_code_based_extension,
        ext_commands,
        ext_compress,
        ext_database,
        ext_enterprise_telemetry,
        ext_fastopenapi,
        ext_forward_refs,
        ext_hosting_provider,
        ext_import_modules,
        ext_logging,
        ext_login,
        ext_logstore,
        ext_mail,
        ext_migrate,
        ext_orjson,
        ext_otel,
        ext_proxy_fix,
        ext_redis,
        ext_request_logging,
        ext_sentry,
        ext_session_factory,
        ext_set_secretkey,
        ext_storage,
        ext_timezone,
        ext_warnings,
    )

    init_flask_context()

    # Order is significant: extensions are initialized top to bottom.
    ordered_extensions = (
        ext_timezone,
        ext_logging,
        ext_warnings,
        ext_import_modules,
        ext_orjson,
        ext_forward_refs,
        ext_set_secretkey,
        ext_compress,
        ext_code_based_extension,
        ext_database,
        ext_app_metrics,
        ext_migrate,
        ext_redis,
        ext_storage,
        ext_logstore,  # must come after storage and before celery
        ext_celery,
        ext_login,
        ext_mail,
        ext_hosting_provider,
        ext_sentry,
        ext_proxy_fix,
        ext_blueprints,
        ext_commands,
        ext_fastopenapi,
        ext_otel,
        ext_enterprise_telemetry,
        ext_request_logging,
        ext_session_factory,
    )

    for extension in ordered_extensions:
        # Last dotted component of the module name, used only for log output.
        label = extension.__name__.rpartition(".")[2]

        # Modules without an is_enabled() hook are treated as enabled.
        if hasattr(extension, "is_enabled") and not extension.is_enabled():
            if dify_config.DEBUG:
                logger.info("Skipped %s", label)
            continue

        began = time.perf_counter()
        extension.init_app(app)
        elapsed_ms = round((time.perf_counter() - began) * 1000, 2)
        if dify_config.DEBUG:
            logger.info("Loaded %s (%s ms)", label, elapsed_ms)
def create_migrations_app() -> DifyApp:
    """
    Create an app with only the database and migration extensions initialized.

    Returns:
        A ``DifyApp`` on which ``ext_database`` and then ``ext_migrate`` have
        been initialized; no other extension is touched.
    """
    migrations_app = create_flask_app_with_configs()

    # Imported after the app is created, and limited to what migrations need.
    from extensions import ext_database, ext_migrate

    # Database first, then the migration extension.
    for required_ext in (ext_database, ext_migrate):
        required_ext.init_app(migrations_app)

    return migrations_app