mirror of
https://github.com/langgenius/dify.git
synced 2026-05-12 15:58:19 +08:00
651 lines
22 KiB
Python
651 lines
22 KiB
Python
"""OAuth bearer primitives.
|
|
|
|
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
|
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
|
Authenticator + validate_bearer stay untouched.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Callable, Iterable
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from enum import StrEnum
|
|
from functools import wraps
|
|
from typing import Literal, ParamSpec, Protocol, TypeVar
|
|
|
|
from flask import g, request
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.orm import Session
|
|
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
|
|
|
|
from configs import dify_config
|
|
from extensions.ext_database import db
|
|
from extensions.ext_redis import redis_client
|
|
from libs.rate_limit import enforce_bearer_rate_limit
|
|
from models import Account, OAuthAccessToken, TenantAccountJoin
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ============================================================================
|
|
# Contract — types, enums, protocols
|
|
# ============================================================================
|
|
|
|
|
|
class SubjectType(StrEnum):
|
|
ACCOUNT = "account"
|
|
EXTERNAL_SSO = "external_sso"
|
|
|
|
|
|
class Scope(StrEnum):
|
|
"""Catalog of bearer scopes recognised by the openapi surface.
|
|
|
|
`FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
|
|
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
|
|
(`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
|
|
"""
|
|
|
|
FULL = "full"
|
|
APPS_READ = "apps:read"
|
|
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
|
|
APPS_RUN = "apps:run"
|
|
|
|
|
|
class Accepts(StrEnum):
|
|
"""Subject types a route is willing to accept as caller."""
|
|
|
|
USER_ACCOUNT = "user_account"
|
|
USER_EXT_SSO = "user_ext_sso"
|
|
|
|
|
|
# Ready-made accept sets for use with ``validate_bearer(accept=...)``.
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset((Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO))
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset((Accepts.USER_EXT_SSO,))

# Translates the subject type of an authenticated token into the Accepts
# member that a route's accept set must contain.
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
    SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
    SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class AuthContext:
|
|
"""Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
|
|
come from the TokenKind, not the DB — corrupt rows can't elevate scope.
|
|
|
|
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
|
|
authenticate time. Per-request mutations write through to Redis via
|
|
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
|
|
"""
|
|
|
|
subject_type: SubjectType
|
|
subject_email: str | None
|
|
subject_issuer: str | None
|
|
account_id: uuid.UUID | None
|
|
client_id: str | None
|
|
scopes: frozenset[Scope]
|
|
token_id: uuid.UUID
|
|
source: str
|
|
expires_at: datetime | None
|
|
token_hash: str
|
|
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ResolvedRow:
|
|
subject_email: str | None
|
|
subject_issuer: str | None
|
|
account_id: uuid.UUID | None
|
|
client_id: str | None
|
|
token_id: uuid.UUID
|
|
expires_at: datetime | None
|
|
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
|
|
|
def to_cache(self) -> dict:
|
|
return {
|
|
"subject_email": self.subject_email,
|
|
"subject_issuer": self.subject_issuer,
|
|
"account_id": str(self.account_id) if self.account_id else None,
|
|
"client_id": self.client_id,
|
|
"token_id": str(self.token_id),
|
|
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
|
"verified_tenants": dict(self.verified_tenants),
|
|
}
|
|
|
|
@classmethod
|
|
def from_cache(cls, data: dict) -> ResolvedRow:
|
|
return cls(
|
|
subject_email=data["subject_email"],
|
|
subject_issuer=data["subject_issuer"],
|
|
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
|
|
client_id=data.get("client_id"),
|
|
token_id=uuid.UUID(data["token_id"]),
|
|
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
|
|
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
|
|
)
|
|
|
|
|
|
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
|
|
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
|
|
|
|
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
|
|
on all live deployments (60s TTL → safe to drop one release after rollout).
|
|
"""
|
|
if not isinstance(raw, dict):
|
|
return {}
|
|
out: dict[str, bool] = {}
|
|
for k, v in raw.items():
|
|
if isinstance(v, bool):
|
|
out[k] = v
|
|
elif v == "ok":
|
|
out[k] = True
|
|
elif v == "denied":
|
|
out[k] = False
|
|
return out
|
|
|
|
|
|
class Resolver(Protocol):
    """Structural interface for looking a token hash up to its row.

    ``resolve`` returns a ResolvedRow, or ``None`` when there is none to return.
    """

    def resolve(self, token_hash: str) -> ResolvedRow | None:  # pragma: no cover - contract
        ...
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TokenKind:
|
|
prefix: str
|
|
subject_type: SubjectType
|
|
scopes: frozenset[Scope]
|
|
source: str
|
|
resolver: Resolver
|
|
|
|
def matches(self, token: str) -> bool:
|
|
return token.startswith(self.prefix)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class MintProfile:
|
|
"""Single source of truth for (subject_type, prefix, scopes) at mint time.
|
|
|
|
Consumers:
|
|
- ``build_registry`` reads scopes here so the resolve-time TokenKind
|
|
cannot drift from the mint-time intent.
|
|
- Device-flow ``approve`` / ``approve-external`` read prefix + scopes
|
|
here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
|
|
- ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
|
|
the (subject_type, prefix, scopes) triple a caller intends to mint
|
|
against this table — a caller that assembles its own scope set
|
|
from a non-canonical source will fail closed at approve time.
|
|
"""
|
|
|
|
subject_type: SubjectType
|
|
prefix: str
|
|
scopes: frozenset[Scope]
|
|
|
|
|
|
# Mint-time profile for every subject type, keyed by the profile's own
# subject_type so the key and the value cannot disagree.
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
    profile.subject_type: profile
    for profile in (
        MintProfile(
            subject_type=SubjectType.ACCOUNT,
            prefix="dfoa_",
            scopes=frozenset({Scope.FULL}),
        ),
        MintProfile(
            subject_type=SubjectType.EXTERNAL_SSO,
            prefix="dfoe_",
            scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
        ),
    )
}
|
|
|
|
|
|
class InvalidBearerError(Exception):
    """Raised when the token is missing, has an unknown prefix, or has no live row."""
|
|
|
|
|
|
class TokenExpiredError(Exception):
    """Token is past its expiry.

    The resolver is expected to finish its hard-expire bookkeeping before
    raising this.
    """
|
|
|
|
|
|
# ============================================================================
|
|
# Registry
|
|
# ============================================================================
|
|
|
|
|
|
class TokenKindRegistry:
    """Ordered, immutable set of TokenKinds looked up by token prefix.

    Prefixes are validated at construction so that `find` can never be
    ambiguous: duplicates, empty prefixes (which would match every token)
    and prefixes that are the start of another registered prefix (which
    would make the result depend on registration order) are all rejected.
    """

    def __init__(self, kinds: Iterable[TokenKind]) -> None:
        """Store ``kinds`` in order.

        Raises:
            ValueError: if any prefix is duplicated, empty, or overlaps
                another registered prefix.
        """
        self._kinds: tuple[TokenKind, ...] = tuple(kinds)
        prefixes = [k.prefix for k in self._kinds]
        if len(set(prefixes)) != len(prefixes):
            raise ValueError(f"duplicate prefix in registry: {prefixes}")
        if not all(prefixes):
            raise ValueError(f"empty prefix in registry: {prefixes}")
        # Prefixes are distinct and non-empty here, so any startswith hit
        # between two different entries is a genuine overlap.
        for i, shorter in enumerate(prefixes):
            for j, longer in enumerate(prefixes):
                if i != j and longer.startswith(shorter):
                    raise ValueError(f"overlapping prefixes in registry: {shorter!r} is a prefix of {longer!r}")

    def find(self, token: str) -> TokenKind | None:
        """Return the kind whose prefix ``token`` starts with, or ``None``."""
        for k in self._kinds:
            if k.matches(token):
                return k
        return None

    def kinds(self) -> tuple[TokenKind, ...]:
        """Return the registered kinds in registration order."""
        return self._kinds
|
|
|
|
|
|
# ============================================================================
|
|
# Authenticator
|
|
# ============================================================================
|
|
|
|
|
|
def sha256_hex(token: str) -> str:
    """Return the lowercase hex SHA-256 digest of ``token`` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
class BearerAuthenticator:
    """Turns a raw bearer token into an AuthContext using a TokenKindRegistry."""

    def __init__(self, registry: TokenKindRegistry) -> None:
        self._registry = registry

    @property
    def registry(self) -> TokenKindRegistry:
        """The registry this authenticator resolves tokens against."""
        return self._registry

    def authenticate(self, token: str) -> AuthContext:
        """Resolve identity and apply the per-token rate limit (single source).

        The openapi pipeline (`BearerCheck`) and the `validate_bearer`
        decorator both go through here, so the rate limit fires exactly once
        per request whichever path hosts the route.

        Raises:
            InvalidBearerError: no registered kind matches the token's
                prefix, or the kind's resolver returned no row.
        """
        matched = self._registry.find(token)
        if matched is None:
            raise InvalidBearerError("unknown token prefix")

        digest = sha256_hex(token)
        resolved = matched.resolver.resolve(digest)
        if resolved is None:
            raise InvalidBearerError("token unknown or revoked")

        # Only tokens that resolved to a row reach the limiter.
        enforce_bearer_rate_limit(digest)

        # subject_type / scopes / source come from the kind, not the row.
        return AuthContext(
            subject_type=matched.subject_type,
            subject_email=resolved.subject_email,
            subject_issuer=resolved.subject_issuer,
            account_id=resolved.account_id,
            client_id=resolved.client_id,
            scopes=matched.scopes,
            token_id=resolved.token_id,
            source=matched.source,
            expires_at=resolved.expires_at,
            token_hash=digest,
            verified_tenants=dict(resolved.verified_tenants),
        )
|
|
|
|
|
|
# ============================================================================
|
|
# OAuth access token resolver (PAT resolver would be a sibling class)
|
|
# ============================================================================
|
|
|
|
# Redis key template for a cached token lookup; ``hash`` is the token's SHA-256 hex.
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
# Default lifetimes (seconds) of positive and negative cache entries.
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
# Audit event name logged when a token is hard-expired.
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"

# Which kind of subject a variant-scoped resolver view serves.
ScopeVariant = Literal["account", "external_sso"]
|
|
|
|
|
|
class OAuthAccessTokenResolver:
    """Shared DB + cache plumbing for OAuth access tokens.

    ``.for_account()`` / ``.for_external_sso()`` return variant-scoped
    Resolver views that share this instance's session factory and cache.
    """

    def __init__(
        self,
        session_factory,
        redis_client,
        positive_ttl: int = POSITIVE_TTL_SECONDS,
        negative_ttl: int = NEGATIVE_TTL_SECONDS,
    ) -> None:
        """
        Args:
            session_factory: zero-argument callable returning a DB session.
            redis_client: client exposing ``get`` / ``setex`` / ``delete``.
            positive_ttl: lifetime in seconds of cached resolved rows.
            negative_ttl: lifetime in seconds of cached "invalid" markers.
        """
        self.session_factory = session_factory
        self._redis = redis_client
        self._positive_ttl = positive_ttl
        self._negative_ttl = negative_ttl

    def for_account(self) -> Resolver:
        """Return a Resolver view for account-bound tokens."""
        return _VariantResolver(self, variant="account")

    def for_external_sso(self) -> Resolver:
        """Return a Resolver view for external-SSO tokens."""
        return _VariantResolver(self, variant="external_sso")

    def _cache_key(self, token_hash: str) -> str:
        return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)

    def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
        """Read the cache entry for ``token_hash``.

        Returns:
            ``None`` on a miss, the string ``"invalid"`` for a negative
            entry, otherwise the decoded ResolvedRow. A malformed entry is
            logged and reported as a miss.
        """
        raw = self._redis.get(self._cache_key(token_hash))
        if raw is None:
            return None
        text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if text == "invalid":
            return "invalid"
        try:
            return ResolvedRow.from_cache(json.loads(text))
        except (ValueError, KeyError, TypeError, AttributeError):
            # TypeError / AttributeError cover payloads that parse as JSON
            # but have the wrong shape (non-object top level, non-string
            # UUID or timestamp fields); they must degrade to a miss, not a
            # 500 on the auth path.
            logger.warning("auth:token cache entry malformed; treating as miss")
            return None

    def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
        """Cache ``row`` as JSON for the positive TTL."""
        self._redis.setex(
            self._cache_key(token_hash),
            self._positive_ttl,
            json.dumps(row.to_cache()),
        )

    def cache_set_negative(self, token_hash: str) -> None:
        """Cache an "invalid" marker for the negative TTL."""
        self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")

    def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
        """Revoke an expired token row and replace its cache entry with a negative one.

        Atomic CAS — the UPDATE only matches while ``revoked_at`` is NULL, so
        only the worker that flips it emits the audit log; replays are
        idempotent.
        """
        stmt = (
            update(OAuthAccessToken)
            .where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        result = session.execute(stmt)
        session.commit()
        if result.rowcount == 1:
            logger.warning(
                "audit: %s token_id=%s",
                AUDIT_OAUTH_EXPIRED,
                row_id,
                extra={"audit": True, "token_id": str(row_id)},
            )
        self._redis.delete(self._cache_key(token_hash))
        self.cache_set_negative(token_hash)
|
|
|
|
|
|
class _VariantResolver:
|
|
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
|
|
self._parent = parent
|
|
self._variant = variant
|
|
|
|
def resolve(self, token_hash: str) -> ResolvedRow | None:
|
|
cached = self._parent.cache_get(token_hash)
|
|
if cached == "invalid":
|
|
return None
|
|
if cached is not None and not isinstance(cached, str):
|
|
if not self._matches_variant(cached):
|
|
return None
|
|
return cached
|
|
|
|
# Flask-SQLAlchemy's scoped_session is request-bound and not a
|
|
# context manager; use it directly.
|
|
session = self._parent.session_factory()
|
|
row = self._load_from_db(session, token_hash)
|
|
if row is None:
|
|
self._parent.cache_set_negative(token_hash)
|
|
return None
|
|
|
|
now = datetime.now(UTC)
|
|
if row.expires_at is not None and row.expires_at <= now:
|
|
self._parent.hard_expire(session, row.id, token_hash)
|
|
return None
|
|
|
|
if not self._matches_variant_model(row):
|
|
logger.error(
|
|
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
|
|
row.id,
|
|
row.prefix,
|
|
)
|
|
return None
|
|
|
|
resolved = ResolvedRow(
|
|
subject_email=row.subject_email,
|
|
subject_issuer=row.subject_issuer,
|
|
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
|
|
client_id=row.client_id,
|
|
token_id=uuid.UUID(str(row.id)),
|
|
expires_at=row.expires_at,
|
|
)
|
|
self._parent.cache_set_positive(token_hash, resolved)
|
|
return resolved
|
|
|
|
def _matches_variant(self, row: ResolvedRow) -> bool:
|
|
has_account = row.account_id is not None
|
|
if self._variant == "account":
|
|
return has_account
|
|
return not has_account
|
|
|
|
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
|
|
has_account = row.account_id is not None
|
|
if self._variant == "account":
|
|
return has_account and row.prefix == "dfoa_"
|
|
return (not has_account) and row.prefix == "dfoe_"
|
|
|
|
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
|
|
return (
|
|
session.query(OAuthAccessToken)
|
|
.filter(
|
|
OAuthAccessToken.token_hash == token_hash,
|
|
OAuthAccessToken.revoked_at.is_(None),
|
|
)
|
|
.one_or_none()
|
|
)
|
|
|
|
|
|
# ============================================================================
|
|
# Layer 0 — workspace membership cache + helper
|
|
# ============================================================================
|
|
|
|
|
|
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
    """Merge a Layer-0 membership verdict into the AuthContext cache entry at
    `auth:token:{hash}`, keeping the entry's remaining TTL.

    No-op if the entry is missing, expired, negative ("invalid") or not a
    JSON object — the next request rebuilds it via authenticate() and
    re-runs Layer 0.

    Args:
        token_hash: SHA-256 hex digest identifying the cache entry.
        tenant_id: tenant the verdict applies to.
        verdict: True for allowed, False for denied.
    """
    cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
    raw = redis_client.get(cache_key)
    if raw is None:
        return
    text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
    if text == "invalid":
        return
    try:
        data = json.loads(text)
    except (ValueError, TypeError):
        return
    # Valid JSON that is not an object (e.g. `null`, a list) cannot hold a
    # verdict map; treat it like any other unusable entry.
    if not isinstance(data, dict):
        return
    ttl = redis_client.ttl(cache_key)
    # A non-positive TTL means the key vanished or has no expiry; either way
    # there is no remaining lifetime to preserve.
    if ttl <= 0:
        return
    verdicts = data.get("verified_tenants")
    if not isinstance(verdicts, dict):
        verdicts = {}
        data["verified_tenants"] = verdicts
    verdicts[tenant_id] = verdict
    # NOTE(review): GET / TTL / SETEX are three separate round trips, so a
    # concurrent writer to this key (e.g. a negative entry set on expiry)
    # can be overwritten here for the remaining TTL — confirm whether that
    # window is acceptable or needs an atomic update.
    redis_client.setex(cache_key, ttl, json.dumps(data))
|
|
|
|
|
|
def check_workspace_membership(
    *,
    account_id: uuid.UUID | str,
    tenant_id: str,
    token_hash: str,
    cached_verdicts: dict[str, bool],
) -> None:
    """Layer-0 enforcement core: return on allow, raise `Forbidden` on deny.

    Used by both the pipeline step (`WorkspaceMembershipCheck`) and the
    inline helper (`require_workspace_member`). Short-circuiting for EE and
    SSO subjects is the caller's job; this function always runs the
    membership and active-status checks unless a cached verdict settles it.
    Every freshly computed verdict is written back via
    `record_layer0_verdict`.
    """
    known = cached_verdicts.get(tenant_id)
    if known is True:
        return
    if known is False:
        raise Forbidden("workspace_membership_revoked")

    membership_query = select(TenantAccountJoin.id).where(
        TenantAccountJoin.account_id == account_id,
        TenantAccountJoin.tenant_id == tenant_id,
    )
    membership_id = db.session.execute(membership_query).scalar_one_or_none()

    if membership_id is None:
        denied = True
    else:
        # The account status is only consulted once membership is established.
        status_query = select(Account.status).where(Account.id == account_id)
        account_status = db.session.execute(status_query).scalar_one_or_none()
        denied = account_status != "active"

    record_layer0_verdict(token_hash, tenant_id, not denied)
    if denied:
        raise Forbidden("workspace_membership_revoked")
|
|
|
|
|
|
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
    """Run `check_workspace_membership` for the subject in ``ctx``.

    Does nothing on EE (gateway RBAC owns tenant isolation) or for SSO
    subjects (no `tenant_account_joins` row by definition).
    """
    if dify_config.ENTERPRISE_ENABLED:
        return

    is_account_subject = ctx.subject_type == SubjectType.ACCOUNT and ctx.account_id is not None
    if is_account_subject:
        check_workspace_membership(
            account_id=ctx.account_id,
            tenant_id=tenant_id,
            token_hash=ctx.token_hash,
            cached_verdicts=ctx.verified_tenants,
        )
|
|
|
|
|
|
# ============================================================================
|
|
# Decorator — route-level bearer gate
|
|
# ============================================================================
|
|
|
|
|
|
_authenticator: BearerAuthenticator | None = None
|
|
|
|
|
|
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
    """Install ``authenticator`` as the module-wide instance, replacing any previous one."""
    global _authenticator
    _authenticator = authenticator
|
|
|
|
|
|
def get_authenticator() -> BearerAuthenticator:
    """Return the bound authenticator.

    Raises:
        RuntimeError: if `bind_authenticator` has not been called yet.
    """
    bound = _authenticator
    if bound is not None:
        return bound
    raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
|
|
|
|
|
|
def _extract_bearer(req) -> str | None:
|
|
header = req.headers.get("Authorization", "")
|
|
scheme, _, value = header.partition(" ")
|
|
if scheme.lower() != "bearer" or not value:
|
|
return None
|
|
return value.strip()
|
|
|
|
|
|
_DP = ParamSpec("_DP")
|
|
_DR = TypeVar("_DR")
|
|
|
|
|
|
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
    """Opt-in: omitting it leaves the route unauthenticated.

    Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
    ``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
    and are rejected here as the wrong auth scheme for this surface.

    On success the AuthContext is stored on ``g.auth_ctx`` before the view
    runs.

    Args:
        accept: subject kinds this route admits.

    Raises (from the wrapped view):
        Unauthorized: no bearer token, or the token is invalid or expired.
        ServiceUnavailable: no authenticator has been bound.
        Forbidden: the token's subject type is not in ``accept``.
    """

    def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
        @wraps(fn)
        def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
            token = _extract_bearer(request)
            if token is None:
                raise Unauthorized("missing bearer token")

            if _authenticator is None:
                raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")

            try:
                ctx = get_authenticator().authenticate(token)
            except InvalidBearerError as e:
                # Chain the cause so the original failure stays in tracebacks.
                raise Unauthorized(str(e)) from e
            except TokenExpiredError as e:
                # A resolver may signal expiry with its own exception type;
                # it is still an authentication failure, not a server error.
                raise Unauthorized(str(e) or "token expired") from e

            if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
                raise Forbidden("token subject type not accepted here")

            g.auth_ctx = ctx
            return fn(*args, **kwargs)

        return inner

    return wrap
|
|
|
|
|
|
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
|
|
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
|
|
without the authenticator, so fail fast instead of approving silently.
|
|
"""
|
|
|
|
@wraps(fn)
|
|
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
if not dify_config.ENABLE_OAUTH_BEARER:
|
|
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
|
return fn(*args, **kwargs)
|
|
|
|
return inner
|
|
|
|
|
|
def require_scope(scope: Scope) -> Callable:
    """Route-level scope gate.

    Must run AFTER validate_bearer so that g.auth_ctx is already set. The
    view runs when the bearer holds either ``scope`` or `Scope.FULL`;
    otherwise Forbidden('insufficient_scope: <scope>') is raised. A missing
    g.auth_ctx is a wiring error and raises RuntimeError.
    """

    def wrap(fn: Callable) -> Callable:
        @wraps(fn)
        def guarded(*args, **kwargs):
            ctx = getattr(g, "auth_ctx", None)
            if ctx is None:
                raise RuntimeError(
                    "require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
                )
            granted = ctx.scopes
            if Scope.FULL in granted or scope in granted:
                return fn(*args, **kwargs)
            raise Forbidden(f"insufficient_scope: {scope}")

        return guarded

    return wrap
|
|
|
|
|
|
# ============================================================================
|
|
# Wiring — called once from the app factory
|
|
# ============================================================================
|
|
|
|
|
|
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
    """Build the registry of OAuth token kinds.

    Prefix, subject type and scopes of each kind are taken from
    MINTABLE_PROFILES; both kinds share one OAuthAccessTokenResolver built
    from ``session_factory`` and ``redis_client``.
    """
    oauth = OAuthAccessTokenResolver(session_factory, redis_client)
    # (mint profile, source label, variant-scoped resolver) per token kind.
    wiring = (
        (MINTABLE_PROFILES[SubjectType.ACCOUNT], "oauth_account", oauth.for_account()),
        (MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO], "oauth_external_sso", oauth.for_external_sso()),
    )
    return TokenKindRegistry(
        [
            TokenKind(
                prefix=profile.prefix,
                subject_type=profile.subject_type,
                scopes=profile.scopes,
                source=source,
                resolver=resolver,
            )
            for profile, source, resolver in wiring
        ]
    )
|
|
|
|
|
|
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
    """Build the registry, wrap it in an authenticator, bind it module-wide and return it."""
    authenticator = BearerAuthenticator(build_registry(session_factory, redis_client))
    bind_authenticator(authenticator)
    return authenticator
|