mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
Type and lint pass over the openapi controllers, auth pipeline, and
oauth bearer/device-flow plumbing. Down from 36 pyright errors and 16
ruff errors to 0/0; 93 openapi unit tests pass.
Logic fixes:
- libs/oauth_bearer.py: drop private-naming on the friend-API methods
consumed by _VariantResolver (cache_get / cache_set_positive /
cache_set_negative / hard_expire / session_factory). They were always
cross-class accessors — leading underscore was misleading. Add public
registry property on BearerAuthenticator. hard_expire's row_id parameter is
widened to UUID | str (matches the StringUUID column type).
- libs/oauth_bearer.py: type validate_bearer / bearer_feature_required
with ParamSpec / PEP-695 so wrapped routes preserve their signature.
- libs/rate_limit.py: same — typed rate_limit decorator.
- services/oauth_device_flow.py: mint_oauth_token / _upsert accept
Session | scoped_session (Flask-SQLAlchemy proxy). Guard row-is-None
after upsert.
- controllers/openapi/{chat,completion,workflow}_messages.py: tuple-vs-
Mapping shape narrowing on AppGenerateService.generate return —
production returns Mapping, tests mock as (body, status). Validate
through Pydantic Response model in both shapes.
- controllers/openapi/oauth_device.py: replace flask_restx.reqparse (banned)
with Pydantic Request/Query models — DeviceCodeRequest, DevicePollRequest,
DeviceLookupQuery, DeviceMutateRequest. Two PEP-695 generic helpers
(_validate_json / _validate_query) translate ValidationError to BadRequest.
- controllers/openapi/auth/strategies.py: Protocol param-name match
(subject_type), Optional narrowing on app/tenant/account_id/subject_email.
- controllers/openapi/auth/steps.py: subject_type-is-None guard before
mounter dispatch.
- core/app/apps/workflow/generate_task_pipeline.py + models/workflow.py:
add WorkflowAppLogCreatedFrom.OPENAPI + matching match-case branch.
Fixes match-exhaustiveness and possibly-unbound created_from.
- libs/device_flow_security.py: pyright ignore on flask after_request
hook (registered by the framework, pyright sees as unused).
- services/oauth_device_flow.py: rename Exceptions to *Error suffix
(StateNotFoundError / InvalidTransitionError / UserCodeExhaustedError);
same for libs/oauth_bearer.py (InvalidBearerError / TokenExpiredError).
Update all callers across openapi controllers.
- controllers/openapi/{oauth_device,oauth_device_sso}.py +
services/oauth_device_flow.py: switch logger.error in except blocks
to logger.exception (TRY400) — keeps the traceback for ops.
- configs/feature/__init__.py: OPENAPI_KNOWN_CLIENT_IDS computed_field
needs an @property alongside for pyright to see it as a value, not a
method. Matches the existing line-451 pattern.
Plus ruff format + import-sort across the openapi tree (pure formatting).
197 lines
5.7 KiB
Python
197 lines
5.7 KiB
Python
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
|
cookie mint/verify/consume, and anti-framing headers.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import secrets
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime, timedelta
|
|
from functools import wraps
|
|
|
|
from flask import Blueprint
|
|
from werkzeug.exceptions import NotFound
|
|
|
|
from libs import jws
|
|
from libs.token import is_secure
|
|
from services.feature_service import FeatureService, LicenseStatus
|
|
|
|
logger = logging.getLogger(__name__)


# ============================================================================
# enterprise_only decorator
# ============================================================================


# Fail-closed allowlist: any non-EE-active status (default NONE on CE, plus
# INACTIVE / EXPIRED / LOST) is denied. Future LicenseStatus values default to
# denial unless explicitly admitted here. Kept as a frozenset so the allowlist
# cannot be widened by accidental mutation at runtime.
_EE_ENABLED_STATUSES = frozenset({LicenseStatus.ACTIVE, LicenseStatus.EXPIRING})
|
|
|
|
|
|
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
|
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
|
responses don't consume the bucket.
|
|
"""
|
|
|
|
@wraps(view)
|
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
|
settings = FeatureService.get_system_features()
|
|
if settings.license.status not in _EE_ENABLED_STATUSES:
|
|
raise NotFound()
|
|
return view(*args, **kwargs)
|
|
|
|
return decorated
|
|
|
|
|
|
# ============================================================================
# approval_grant cookie
# ============================================================================


APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300  # 5 min
# Consumed-nonce markers live for 2x the cookie TTL. A token replayed late
# (e.g. accepted past its nominal lifetime because of clock skew) still finds
# its nonce already consumed. Derived from the cookie TTL so the documented
# 2x relationship cannot silently drift when the TTL is tuned.
NONCE_TTL_SECONDS = 2 * APPROVAL_GRANT_COOKIE_TTL_SECONDS
# Redis key templates for consumed nonces; ``{nonce}`` is the opaque nonce value.
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
    """Immutable view of the claims carried by an approval-grant token.

    Returned by ``mint_approval_grant`` (alongside the signed token) and by
    ``verify_approval_grant`` (decoded from a token).
    """

    # Email of the subject the grant was minted for.
    subject_email: str
    # Issuer associated with that subject (presumably the identity provider
    # that authenticated them — confirm against callers).
    subject_issuer: str
    # Device-flow user code bound to this grant.
    user_code: str
    # Random opaque value; replay protection depends on the caller consuming
    # it (e.g. via consume_approval_grant_nonce) — verification does not.
    nonce: str
    # Random opaque CSRF token generated per grant.
    csrf_token: str
    # Timezone-aware (UTC) expiry of the grant.
    expires_at: datetime
|
|
|
|
|
|
def mint_approval_grant(
    *,
    keyset: jws.KeySet,
    iss: str,
    subject_email: str,
    subject_issuer: str,
    user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
    """Sign a new approval-grant token and return it with its claims.

    A fresh random ``nonce`` and ``csrf_token`` are generated on every call.
    The token is signed for ``jws.AUD_APPROVAL_GRANT`` with a lifetime of
    ``APPROVAL_GRANT_COOKIE_TTL_SECONDS``.

    Use ``approval_grant_cookie_kwargs`` to set the cookie — single
    source of truth for Path/HttpOnly/Secure/SameSite.

    Returns:
        ``(token, claims)``. ``claims.expires_at`` is computed locally as
        now (UTC) plus the TTL. NOTE(review): it may differ slightly from the
        ``exp`` that ``jws.sign`` embeds in the token — confirm if exact
        agreement matters to callers.
    """
    expires_at = datetime.now(UTC) + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
    claims = ApprovalGrantClaims(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
        nonce=_random_opaque(),
        csrf_token=_random_opaque(),
        expires_at=expires_at,
    )

    token = jws.sign(
        keyset,
        {
            "iss": iss,
            "subject_email": claims.subject_email,
            "subject_issuer": claims.subject_issuer,
            "user_code": claims.user_code,
            "nonce": claims.nonce,
            "csrf_token": claims.csrf_token,
        },
        aud=jws.AUD_APPROVAL_GRANT,
        ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS,
    )
    return token, claims
|
|
|
|
|
|
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
    """Decode ``token`` into :class:`ApprovalGrantClaims`.

    Sig + aud + exp only (delegated to ``jws.verify`` with the approval-grant
    audience). Nonce consumption is the caller's job.

    Raises:
        Whatever ``jws.verify`` raises for an invalid token. A missing claim
        surfaces as a lookup error from indexing the payload (assumes the
        payload is a plain mapping — confirm against ``jws.verify``).
    """
    payload = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)

    subject_email = payload["subject_email"]
    subject_issuer = payload["subject_issuer"]
    user_code = payload["user_code"]
    nonce = payload["nonce"]
    csrf_token = payload["csrf_token"]
    # ``exp`` is a POSIX timestamp; convert to an aware UTC datetime.
    expires_at = datetime.fromtimestamp(payload["exp"], tz=UTC)

    return ApprovalGrantClaims(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
        nonce=nonce,
        csrf_token=csrf_token,
        expires_at=expires_at,
    )
|
|
|
|
|
|
def _claim_nonce_key(redis_client, key: str) -> bool:
    """Atomically mark ``key`` as used; True only for the first claimant.

    ``SET key "1" NX EX NONCE_TTL_SECONDS`` writes only when the key is
    absent. A replay finds the key already present, the client returns a
    falsy value, and this function returns False.
    """
    return bool(
        redis_client.set(
            key,
            "1",
            nx=True,
            ex=NONCE_TTL_SECONDS,
        )
    )


def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
    """Consume an approval-grant nonce.

    Returns:
        True the first time ``nonce`` is consumed. False if it was already
        consumed (within ``NONCE_TTL_SECONDS``) or is empty.
    """
    if not nonce:
        return False
    return _claim_nonce_key(redis_client, NONCE_KEY_FMT.format(nonce=nonce))


def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
    """Consume an SSO-assertion nonce; same contract as the approval-grant one.

    Returns:
        True the first time ``nonce`` is consumed. False if it was already
        consumed (within ``NONCE_TTL_SECONDS``) or is empty.
    """
    if not nonce:
        return False
    return _claim_nonce_key(redis_client, SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce))
|
|
|
|
|
|
def approval_grant_cookie_kwargs(value: str) -> dict:
    """Keyword arguments for setting the approval-grant cookie to ``value``.

    The keys mirror the response ``set_cookie`` parameters. ``secure``
    follows is_secure() so plain-HTTP deployments don't silently drop the
    cookie (a browser would refuse a Secure cookie over HTTP).
    """
    use_secure_flag = is_secure()
    return {
        "key": APPROVAL_GRANT_COOKIE_NAME,
        "value": value,
        "max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
        "path": APPROVAL_GRANT_COOKIE_PATH,
        "secure": use_secure_flag,
        "httponly": True,
        "samesite": "Lax",
    }
|
|
|
|
|
|
def approval_grant_cleared_cookie_kwargs() -> dict:
    """Keyword arguments that expire (clear) the approval-grant cookie.

    Derived from ``approval_grant_cookie_kwargs`` so name, path, secure,
    httponly and samesite always match the cookie that was set. A browser
    only replaces an existing cookie whose name, domain and path match, so
    any drift between the two dicts would leave the real cookie in place.
    """
    kwargs = approval_grant_cookie_kwargs("")
    # Max-Age=0 instructs the browser to discard the cookie immediately.
    kwargs["max_age"] = 0
    return kwargs
|
|
|
|
|
|
def _random_opaque() -> str:
|
|
return secrets.token_urlsafe(16)
|
|
|
|
|
|
# ============================================================================
|
|
# Anti-framing headers
|
|
# ============================================================================
|
|
|
|
|
|
_ANTI_FRAMING_HEADERS = {
|
|
"X-Frame-Options": "DENY",
|
|
"Content-Security-Policy": "frame-ancestors 'none'",
|
|
}
|
|
|
|
|
|
def attach_anti_framing(bp: Blueprint) -> None:
|
|
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
|
|
|
|
@bp.after_request
|
|
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
|
|
for name, value in _ANTI_FRAMING_HEADERS.items():
|
|
response.headers.setdefault(name, value)
|
|
return response
|