dify/api/libs/oauth_bearer.py

651 lines
22 KiB
Python

"""OAuth bearer primitives.
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
Authenticator + validate_bearer stay untouched.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Literal, ParamSpec, Protocol, TypeVar
from flask import g, request
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.rate_limit import enforce_bearer_rate_limit
from models import Account, OAuthAccessToken, TenantAccountJoin
# Module logger; also used for the audit record emitted by hard_expire.
logger = logging.getLogger(__name__)
# ============================================================================
# Contract — types, enums, protocols
# ============================================================================
class SubjectType(StrEnum):
    """Kind of principal a bearer token authenticates as.

    Each member is mapped to an ``Accepts`` value by ``_SUBJECT_TO_ACCEPT``
    and to its mint-time profile in ``MINTABLE_PROFILES``.
    """

    ACCOUNT = "account"
    EXTERNAL_SSO = "external_sso"
class Scope(StrEnum):
    """Catalog of bearer scopes recognised by the openapi surface.
    `FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
    any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
    (`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
    """

    # Satisfies every require_scope check on its own.
    FULL = "full"
    # NOTE(review): not part of any MINTABLE_PROFILES entry in this module —
    # confirm which tokens (if any) carry it.
    APPS_READ = "apps:read"
    APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
    APPS_RUN = "apps:run"
class Accepts(StrEnum):
    """Subject types a route is willing to accept as caller."""

    # Matches SubjectType.ACCOUNT (see _SUBJECT_TO_ACCEPT).
    USER_ACCOUNT = "user_account"
    # Matches SubjectType.EXTERNAL_SSO (see _SUBJECT_TO_ACCEPT).
    USER_EXT_SSO = "user_ext_sso"
# Ready-made `accept=` sets for validate_bearer.
# Routes open to both account and external-SSO bearers.
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
# Routes restricted to external-SSO bearers.
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
# Translates the authenticated subject type into the Accepts member that
# validate_bearer checks against a route's `accept` set.
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
    SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
    SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
@dataclass(frozen=True, slots=True)
class AuthContext:
    """Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
    come from the TokenKind, not the DB — corrupt rows can't elevate scope.
    `verified_tenants` is a snapshot of the Layer-0 verdict cache at
    authenticate time. Per-request mutations write through to Redis via
    `record_layer0_verdict`; this snapshot is not updated in place (frozen).

    ``frozen`` only prevents rebinding fields; the ``verified_tenants`` dict
    itself is still mutable, so treat it as read-only.
    NOTE(review): because a dict field is unhashable, the dataclass-generated
    ``__hash__`` makes ``hash(ctx)`` raise TypeError — don't use as a dict key.
    """

    subject_type: SubjectType
    subject_email: str | None
    subject_issuer: str | None
    account_id: uuid.UUID | None  # None for external-SSO subjects
    client_id: str | None
    scopes: frozenset[Scope]
    token_id: uuid.UUID
    source: str  # TokenKind.source, e.g. "oauth_account"
    expires_at: datetime | None
    token_hash: str  # sha256 hex of the raw bearer; key for the Redis cache entry
    verified_tenants: dict[str, bool] = field(default_factory=dict)  # tenant_id -> Layer-0 verdict
@dataclass(frozen=True, slots=True)
class ResolvedRow:
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
client_id: str | None
token_id: uuid.UUID
expires_at: datetime | None
verified_tenants: dict[str, bool] = field(default_factory=dict)
def to_cache(self) -> dict:
return {
"subject_email": self.subject_email,
"subject_issuer": self.subject_issuer,
"account_id": str(self.account_id) if self.account_id else None,
"client_id": self.client_id,
"token_id": str(self.token_id),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"verified_tenants": dict(self.verified_tenants),
}
@classmethod
def from_cache(cls, data: dict) -> ResolvedRow:
return cls(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
client_id=data.get("client_id"),
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
)
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
on all live deployments (60s TTL → safe to drop one release after rollout).
"""
if not isinstance(raw, dict):
return {}
out: dict[str, bool] = {}
for k, v in raw.items():
if isinstance(v, bool):
out[k] = v
elif v == "ok":
out[k] = True
elif v == "denied":
out[k] = False
return out
class Resolver(Protocol):
    """Structural interface: map a token hash to its live ``ResolvedRow``."""

    def resolve(self, token_hash: str) -> ResolvedRow | None:  # pragma: no cover - contract
        """Return the row for ``token_hash``, or ``None`` if there is no live row.

        ``BearerAuthenticator.authenticate`` turns ``None`` into
        ``InvalidBearerError("token unknown or revoked")``.
        """
        ...
@dataclass(frozen=True, slots=True)
class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[Scope]
source: str
resolver: Resolver
def matches(self, token: str) -> bool:
return token.startswith(self.prefix)
@dataclass(frozen=True, slots=True)
class MintProfile:
    """Single source of truth for (subject_type, prefix, scopes) at mint time.
    Consumers:
    - ``build_registry`` reads scopes here so the resolve-time TokenKind
      cannot drift from the mint-time intent.
    - Device-flow ``approve`` / ``approve-external`` read prefix + scopes
      here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
    - ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
      the (subject_type, prefix, scopes) triple a caller intends to mint
      against this table — a caller that assembles its own scope set
      from a non-canonical source will fail closed at approve time.
    """

    subject_type: SubjectType
    prefix: str  # also used by build_registry as the TokenKind prefix
    scopes: frozenset[Scope]
# Canonical mint profiles keyed by subject type. build_registry looks up both
# entries by key, so every SubjectType that must resolve at runtime needs one.
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
    # Account tokens: catch-all scope.
    SubjectType.ACCOUNT: MintProfile(
        subject_type=SubjectType.ACCOUNT,
        prefix="dfoa_",
        scopes=frozenset({Scope.FULL}),
    ),
    # External-SSO tokens: per-feature scopes only.
    SubjectType.EXTERNAL_SSO: MintProfile(
        subject_type=SubjectType.EXTERNAL_SSO,
        prefix="dfoe_",
        scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
    ),
}
class InvalidBearerError(Exception):
    """Token missing, unknown prefix, or no live row.

    ``validate_bearer`` converts it into a 401 ``Unauthorized`` carrying
    the exception message.
    """
class TokenExpiredError(Exception):
    """Hard-expire bookkeeping is the resolver's job before raising.

    NOTE(review): nothing in this module raises it — the OAuth resolver
    hard-expires and returns ``None`` instead. Presumably reserved for other
    resolvers; confirm before relying on it.
    """
# ============================================================================
# Registry
# ============================================================================
class TokenKindRegistry:
    """Ordered collection of ``TokenKind``s, looked up by token prefix.

    Prefixes must be unique (checked at construction). When more than one
    registered prefix matches a token (e.g. ``"dfo"`` and ``"dfoa_"``),
    ``find`` returns the kind with the longest prefix. The most specific
    registration therefore wins regardless of registration order, instead
    of being silently shadowed by an earlier, shorter prefix.
    """

    def __init__(self, kinds: Iterable[TokenKind]) -> None:
        self._kinds: tuple[TokenKind, ...] = tuple(kinds)
        prefixes = [k.prefix for k in self._kinds]
        if len(set(prefixes)) != len(prefixes):
            raise ValueError(f"duplicate prefix in registry: {prefixes}")

    def find(self, token: str) -> TokenKind | None:
        """Return the kind whose prefix matches ``token`` (longest wins), or ``None``.

        With non-overlapping prefixes (as in ``MINTABLE_PROFILES``) this is
        identical to first-match lookup.
        """
        best: TokenKind | None = None
        for kind in self._kinds:
            if kind.matches(token) and (best is None or len(kind.prefix) > len(best.prefix)):
                best = kind
        return best

    def kinds(self) -> tuple[TokenKind, ...]:
        """Return the registered kinds in registration order."""
        return self._kinds
# ============================================================================
# Authenticator
# ============================================================================
def sha256_hex(token: str) -> str:
    """Return the lowercase hex SHA-256 digest of ``token`` encoded as UTF-8.

    ``BearerAuthenticator`` passes this digest, not the raw token, to the
    resolvers as the lookup key.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
class BearerAuthenticator:
    """Turn a raw bearer string into an ``AuthContext``.

    Subject type, scopes and source come from the matched ``TokenKind``;
    subject identity comes from the kind's resolver.
    """

    def __init__(self, registry: TokenKindRegistry) -> None:
        self._registry = registry

    @property
    def registry(self) -> TokenKindRegistry:
        """Registry used to match token prefixes."""
        return self._registry

    def authenticate(self, token: str) -> AuthContext:
        """Identity + per-token rate limit (single source).

        Both the openapi pipeline (`BearerCheck`) and the decorator
        (`validate_bearer`) call this — rate-limit fires exactly once per
        request regardless of which path hosts the route.

        Raises:
            InvalidBearerError: unknown prefix, or the resolver returned no row.
        """
        matched = self._registry.find(token)
        if matched is None:
            raise InvalidBearerError("unknown token prefix")

        digest = sha256_hex(token)
        resolved = matched.resolver.resolve(digest)
        if resolved is None:
            raise InvalidBearerError("token unknown or revoked")

        # Only tokens that resolved reach the rate limiter, keyed by hash.
        enforce_bearer_rate_limit(digest)

        return AuthContext(
            subject_type=matched.subject_type,
            subject_email=resolved.subject_email,
            subject_issuer=resolved.subject_issuer,
            account_id=resolved.account_id,
            client_id=resolved.client_id,
            scopes=matched.scopes,
            token_id=resolved.token_id,
            source=matched.source,
            expires_at=resolved.expires_at,
            token_hash=digest,
            # Fresh dict so the context never aliases the resolver's row.
            verified_tenants=dict(resolved.verified_tenants),
        )
# ============================================================================
# OAuth access token resolver (PAT resolver would be a sibling class)
# ============================================================================
# Redis key holding either a JSON ResolvedRow (positive) or "invalid" (negative).
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
# Default lifetime, in seconds, of a positive cache entry.
POSITIVE_TTL_SECONDS = 60
# Default lifetime, in seconds, of a negative ("invalid") cache entry.
NEGATIVE_TTL_SECONDS = 10
# Audit event name logged when hard_expire revokes a token.
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
# Variant selector understood by _VariantResolver.
ScopeVariant = Literal["account", "external_sso"]
class OAuthAccessTokenResolver:
    """Shared DB + Redis plumbing for ``OAuthAccessToken`` lookups.

    ``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
    sharing this instance's session factory, Redis client and TTLs. Cache
    entries live under ``TOKEN_CACHE_KEY_FMT``: a JSON ``ResolvedRow`` for
    a positive entry, the literal ``"invalid"`` for a negative one.
    """

    def __init__(
        self,
        session_factory,
        redis_client,
        positive_ttl: int = POSITIVE_TTL_SECONDS,
        negative_ttl: int = NEGATIVE_TTL_SECONDS,
    ) -> None:
        # session_factory is called with no arguments to obtain a Session.
        self.session_factory = session_factory
        self._redis = redis_client
        self._positive_ttl = positive_ttl
        self._negative_ttl = negative_ttl

    def for_account(self) -> Resolver:
        """Resolver view accepting only rows that carry an account_id."""
        return _VariantResolver(self, variant="account")

    def for_external_sso(self) -> Resolver:
        """Resolver view accepting only rows without an account_id."""
        return _VariantResolver(self, variant="external_sso")

    def _cache_key(self, token_hash: str) -> str:
        return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)

    def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
        """Read the cached verdict for ``token_hash``.

        Returns:
            ``None`` on a miss or an unparseable entry, ``"invalid"`` for a
            negative entry, otherwise the cached ``ResolvedRow``.
        """
        raw = self._redis.get(self._cache_key(token_hash))
        if raw is None:
            return None
        text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if text == "invalid":
            return "invalid"
        try:
            return ResolvedRow.from_cache(json.loads(text))
        except (ValueError, KeyError, TypeError, AttributeError):
            # Structurally wrong payloads (non-object JSON, non-string UUID /
            # ISO fields) raise TypeError / AttributeError rather than
            # ValueError; treat them as a miss instead of failing the request.
            logger.warning("auth:token cache entry malformed; treating as miss")
            return None

    def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
        """Cache ``row`` for ``positive_ttl`` seconds."""
        self._redis.setex(
            self._cache_key(token_hash),
            self._positive_ttl,
            json.dumps(row.to_cache()),
        )

    def cache_set_negative(self, token_hash: str) -> None:
        """Mark ``token_hash`` as invalid for ``negative_ttl`` seconds."""
        self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")

    def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
        """Revoke an expired row and negatively cache its hash.

        Atomic CAS — the UPDATE only matches while ``revoked_at`` is NULL, so
        only the worker that flips it emits the audit record; replays are
        idempotent. Also clears the row's ``token_hash`` and commits
        ``session``.
        """
        stmt = (
            update(OAuthAccessToken)
            .where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        result = session.execute(stmt)
        session.commit()
        if result.rowcount == 1:
            logger.warning(
                "audit: %s token_id=%s",
                AUDIT_OAUTH_EXPIRED,
                row_id,
                extra={"audit": True, "token_id": str(row_id)},
            )
        self._redis.delete(self._cache_key(token_hash))
        self.cache_set_negative(token_hash)
class _VariantResolver:
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
self._parent = parent
self._variant = variant
def resolve(self, token_hash: str) -> ResolvedRow | None:
cached = self._parent.cache_get(token_hash)
if cached == "invalid":
return None
if cached is not None and not isinstance(cached, str):
if not self._matches_variant(cached):
return None
return cached
# Flask-SQLAlchemy's scoped_session is request-bound and not a
# context manager; use it directly.
session = self._parent.session_factory()
row = self._load_from_db(session, token_hash)
if row is None:
self._parent.cache_set_negative(token_hash)
return None
now = datetime.now(UTC)
if row.expires_at is not None and row.expires_at <= now:
self._parent.hard_expire(session, row.id, token_hash)
return None
if not self._matches_variant_model(row):
logger.error(
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
row.id,
row.prefix,
)
return None
resolved = ResolvedRow(
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
client_id=row.client_id,
token_id=uuid.UUID(str(row.id)),
expires_at=row.expires_at,
)
self._parent.cache_set_positive(token_hash, resolved)
return resolved
def _matches_variant(self, row: ResolvedRow) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account
return not has_account
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account and row.prefix == "dfoa_"
return (not has_account) and row.prefix == "dfoe_"
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
return (
session.query(OAuthAccessToken)
.filter(
OAuthAccessToken.token_hash == token_hash,
OAuthAccessToken.revoked_at.is_(None),
)
.one_or_none()
)
# ============================================================================
# Layer 0 — workspace membership cache + helper
# ============================================================================
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
    """Merge a Layer-0 membership verdict into the AuthContext cache entry at
    `auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
    rebuilds via authenticate() and re-runs Layer 0.

    The entry's remaining TTL is preserved. Entries that are not a JSON
    object (or are undecodable) are left untouched rather than raising; the
    reader treats them as a cache miss.
    """
    cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
    raw = redis_client.get(cache_key)
    if raw is None:
        return
    try:
        text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if text == "invalid":
            return
        data = json.loads(text)
    except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
        return
    if not isinstance(data, dict):
        return
    ttl = redis_client.ttl(cache_key)
    if ttl <= 0:  # -2: key gone, -1: no expiry — don't resurrect or pin it
        return
    tenants = data.get("verified_tenants")
    if not isinstance(tenants, dict):
        tenants = {}
    tenants[tenant_id] = verdict
    data["verified_tenants"] = tenants
    # NOTE(review): GET -> TTL -> SETEX is not atomic; a concurrent write to
    # the same key (e.g. hard_expire's "invalid" entry) landing in between
    # is overwritten here for the remaining TTL.
    redis_client.setex(cache_key, ttl, json.dumps(data))
def check_workspace_membership(
*,
account_id: uuid.UUID | str,
tenant_id: str,
token_hash: str,
cached_verdicts: dict[str, bool],
) -> None:
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
inline helper (`require_workspace_member`). Caller is responsible for
short-circuiting on EE / SSO subjects before invoking — this function
runs the membership + active-status checks unconditionally.
"""
cached = cached_verdicts.get(tenant_id)
if cached is True:
return
if cached is False:
raise Forbidden("workspace_membership_revoked")
join = db.session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.account_id == account_id,
TenantAccountJoin.tenant_id == tenant_id,
)
).scalar_one_or_none()
if join is None:
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
if status != "active":
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
record_layer0_verdict(token_hash, tenant_id, True)
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
    """Apply Layer-0 membership to an ``AuthContext`` for ``tenant_id``.

    Skipped on EE (gateway RBAC owns tenant isolation) and for non-account
    subjects or contexts without an ``account_id`` (SSO subjects have no
    `tenant_account_joins` row by definition). Otherwise delegates to
    `check_workspace_membership`, which raises ``Forbidden`` on deny.
    """
    exempt = (
        dify_config.ENTERPRISE_ENABLED
        or ctx.subject_type != SubjectType.ACCOUNT
        or ctx.account_id is None
    )
    if exempt:
        return
    check_workspace_membership(
        account_id=ctx.account_id,
        tenant_id=tenant_id,
        token_hash=ctx.token_hash,
        cached_verdicts=ctx.verified_tenants,
    )
# ============================================================================
# Decorator — route-level bearer gate
# ============================================================================
# Process-wide authenticator installed by bind_authenticator. While None,
# validate_bearer answers 503 (bearer auth disabled).
_authenticator: BearerAuthenticator | None = None
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
    """Install ``authenticator`` as the module-wide instance.

    Replaces any previously bound authenticator; ``validate_bearer`` and
    ``get_authenticator`` read it on every call.
    """
    global _authenticator
    _authenticator = authenticator
def get_authenticator() -> BearerAuthenticator:
    """Return the bound authenticator.

    Raises:
        RuntimeError: ``bind_authenticator`` has not been called.
    """
    bound = _authenticator
    if bound is not None:
        return bound
    raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
def _extract_bearer(req) -> str | None:
    """Return the credential from an ``Authorization: Bearer <token>`` header.

    The scheme is matched case-insensitively and surrounding whitespace is
    stripped from the credential. Returns ``None`` when the header is
    absent, uses another scheme, or carries no non-blank credential. For
    example, ``"Bearer   "`` counts as a missing token rather than being
    passed on as an empty string.

    Args:
        req: Object exposing a ``headers`` mapping with ``.get`` (the Flask
            request in production).
    """
    header = req.headers.get("Authorization", "")
    scheme, _, value = header.partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = value.strip()
    return token or None
# Parameter spec / return type of the route function wrapped by validate_bearer.
_DP = ParamSpec("_DP")
_DR = TypeVar("_DR")
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
    """Opt-in: omitting it leaves the route unauthenticated.
    Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
    ``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
    and are rejected here as the wrong auth scheme for this surface.

    On success the ``AuthContext`` is stored on ``g.auth_ctx``.

    Raises (from the wrapped route):
        Unauthorized: no bearer header, or the token is invalid/unknown.
        ServiceUnavailable: no authenticator bound (bearer auth disabled).
        Forbidden: the token's subject type is not in ``accept`` — including
            a subject type with no ``_SUBJECT_TO_ACCEPT`` mapping, which fails
            closed instead of surfacing as a KeyError/500.
    """

    def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
        @wraps(fn)
        def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
            token = _extract_bearer(request)
            if token is None:
                raise Unauthorized("missing bearer token")
            if _authenticator is None:
                raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
            try:
                ctx = get_authenticator().authenticate(token)
            except InvalidBearerError as e:
                raise Unauthorized(str(e)) from e
            if _SUBJECT_TO_ACCEPT.get(ctx.subject_type) not in accept:
                raise Forbidden("token subject type not accepted here")
            g.auth_ctx = ctx
            return fn(*args, **kwargs)

        return inner

    return wrap
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
without the authenticator, so fail fast instead of approving silently.
"""
@wraps(fn)
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ENABLE_OAUTH_BEARER:
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
return fn(*args, **kwargs)
return inner
def require_scope(scope: Scope) -> Callable:
    """Route-level scope gate; must run AFTER ``validate_bearer``.

    ``validate_bearer`` sets ``g.auth_ctx``; without it this raises
    RuntimeError. The call is allowed when the bearer holds ``scope`` or the
    catch-all ``Scope.FULL``; otherwise it raises
    ``Forbidden('insufficient_scope: <scope>')``.
    """

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def checked(*args, **kwargs):
            auth_ctx = getattr(g, "auth_ctx", None)
            if auth_ctx is None:
                raise RuntimeError(
                    "require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
                )
            granted = any(candidate in auth_ctx.scopes for candidate in (Scope.FULL, scope))
            if not granted:
                raise Forbidden(f"insufficient_scope: {scope}")
            return fn(*args, **kwargs)

        return checked

    return decorator
# ============================================================================
# Wiring — called once from the app factory
# ============================================================================
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
    """Build the production registry of OAuth token kinds.

    Both kinds share one ``OAuthAccessTokenResolver``. Prefix, subject type
    and scopes come from ``MINTABLE_PROFILES`` so resolve-time kinds cannot
    drift from what was minted.
    """
    shared = OAuthAccessTokenResolver(session_factory, redis_client)
    wiring = (
        (SubjectType.ACCOUNT, "oauth_account", shared.for_account()),
        (SubjectType.EXTERNAL_SSO, "oauth_external_sso", shared.for_external_sso()),
    )
    kinds: list[TokenKind] = []
    for subject, source, resolver in wiring:
        profile = MINTABLE_PROFILES[subject]
        kinds.append(
            TokenKind(
                prefix=profile.prefix,
                subject_type=profile.subject_type,
                scopes=profile.scopes,
                source=source,
                resolver=resolver,
            )
        )
    return TokenKindRegistry(kinds)
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
    """Build the registry, wrap it in an authenticator, bind and return it.

    Intended to be called once from the app factory.
    """
    authenticator = BearerAuthenticator(build_registry(session_factory, redis_client))
    bind_authenticator(authenticator)
    return authenticator