feat(openapi): bearer auth pipeline + Layer 0 + per-token rate limit (CE)

Bearer auth surface for /openapi/v1/* run-routes:

- OAUTH_BEARER_PIPELINE (renamed from APP_PIPELINE for clarity outside this
  module) composes BearerCheck → ScopeCheck → AppResolver →
  WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
- BearerAuthenticator.authenticate() is the single source of identity +
  rate-limit. Both pipeline (BearerCheck) and decorator (validate_bearer)
  delegate to it, so per-token rate limit fires exactly once per request.
- Layer 0 (workspace membership) is CE-only; on EE the gateway owns
  tenant isolation. Verdicts are cached on the AuthContext entry as
  verified_tenants: dict[str, bool] (legacy "ok"/"denied" strings tolerated
  by from_cache for one TTL cycle, then removed).
- check_workspace_membership(...) is the shared core; the pipeline step
  and the inline require_workspace_member helper both delegate to it.
- Per-token rate limit: 60/min sliding window, RFC-7231-compliant 429
  with Retry-After header + JSON body { error, retry_after_ms }. Bucket
  key is sha256(token) so all replicas share state via Redis.

API hygiene:
- Scope StrEnum (FULL, APPS_READ, APPS_RUN) replaces bare string literals.
- /openapi/v1/apps/<id>/info: scope flipped from apps:run to apps:read.
- /info migrates off the pipeline to validate_bearer + require_scope +
  require_workspace_member (no AppAuthzCheck/CallerMount needed for reads).
- ResolvedRow gains to_cache() / from_cache() classmethods.
- AuthContext gains token_hash + verified_tenants, dropping the per-route
  re-hash and per-request Redis read on the cache hit path.

OPENAPI_RATE_LIMIT_PER_TOKEN config (default 60).
This commit is contained in:
GareArc 2026-05-05 18:07:47 -07:00
parent 8a62c1d915
commit 591048d7c2
No known key found for this signature in database
22 changed files with 808 additions and 180 deletions

View File

@ -908,6 +908,12 @@ class AuthConfig(BaseSettings):
default=True,
)
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
default=60,
)
class ModerationConfig(BaseSettings):
"""

View File

@ -2,36 +2,34 @@
from __future__ import annotations
from flask import g
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import APP_PIPELINE
class AppInfoResponse(BaseModel):
id: str
name: str
description: str | None = None
mode: str
author_name: str | None = None
tags: list[str] = []
def _unpack_app(app_model):
return app_model
from controllers.openapi.apps import account_or_404, app_info_payload
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
require_scope,
require_workspace_member,
validate_bearer,
)
from models import App
@openapi_ns.route("/apps/<string:app_id>/info")
class AppInfoApi(Resource):
@APP_PIPELINE.guard(scope="apps:run")
def get(self, app_id, app_model, caller, caller_kind):
app = _unpack_app(app_model)
return AppInfoResponse(
id=app.id,
name=app.name,
description=app.description,
mode=app.mode,
author_name=app.author_name,
tags=[t.name for t in app.tags],
).model_dump(mode="json")
@validate_bearer(accept=ACCEPT_USER_ANY)
@require_scope(Scope.APPS_READ) # type: ignore[reportUntypedFunctionDecorator]
def get(self, app_id: str):
ctx = g.auth_ctx
account_or_404(ctx)
app = db.session.get(App, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
require_workspace_member(ctx, str(app.tenant_id))
return app_info_payload(app), 200

View File

@ -1,3 +1,3 @@
from controllers.openapi.auth.composition import APP_PIPELINE
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
__all__ = ["APP_PIPELINE"]
__all__ = ["OAUTH_BEARER_PIPELINE"]

View File

@ -1,6 +1,9 @@
"""APP_PIPELINE — the only auth scheme for openapi app endpoints.
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
Endpoints attach via @APP_PIPELINE.guard(scope=). No alternative paths.
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
@ -12,6 +15,7 @@ from controllers.openapi.auth.steps import (
BearerCheck,
CallerMount,
ScopeCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
@ -29,10 +33,11 @@ def _resolve_app_authz_strategy() -> AppAuthzStrategy:
return MembershipStrategy()
APP_PIPELINE = Pipeline(
OAUTH_BEARER_PIPELINE = Pipeline(
BearerCheck(),
ScopeCheck(),
AppResolver(),
WorkspaceMembershipCheck(),
AppAuthzCheck(_resolve_app_authz_strategy),
CallerMount(AccountMounter(), EndUserMounter()),
)

View File

@ -7,29 +7,35 @@ read populated values via the decorator's kwargs unpacking.
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Literal, Protocol
from typing import TYPE_CHECKING, Literal, Protocol
from flask import Request
from libs.oauth_bearer import SubjectType
from libs.oauth_bearer import Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
request: Request
required_scope: str
required_scope: Scope
subject_type: SubjectType | None = None
subject_email: str | None = None
subject_issuer: str | None = None
account_id: str | None = None
scopes: frozenset[str] = field(default_factory=frozenset)
token_id: str | None = None
account_id: uuid.UUID | None = None
scopes: frozenset[Scope] = field(default_factory=frozenset)
token_id: uuid.UUID | None = None
token_hash: str | None = None
cached_verified_tenants: dict[str, bool] | None = None
source: str | None = None
expires_at: datetime | None = None
app: object | None = None
tenant: object | None = None
app: App | None = None
tenant: Tenant | None = None
caller: object | None = None
caller_kind: Literal["account", "end_user"] | None = None

View File

@ -12,6 +12,7 @@ from functools import wraps
from flask import request
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope
class Pipeline:
@ -22,7 +23,7 @@ class Pipeline:
for step in self._steps:
step(ctx)
def guard(self, *, scope: str):
def guard(self, *, scope: Scope):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):

View File

@ -1,7 +1,7 @@
"""Pipeline steps. Each is one responsibility.
BearerCheck is the only step that touches the token registry; downstream
steps see only the populated Context.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`.
"""
from __future__ import annotations
@ -10,62 +10,52 @@ from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from extensions.ext_database import db
from libs.oauth_bearer import TokenExpiredError, get_authenticator, sha256_hex
from libs.oauth_bearer import (
InvalidBearerError,
Scope,
SubjectType,
_extract_bearer, # type: ignore[attr-defined]
check_workspace_membership,
get_authenticator,
)
from models import App, Tenant, TenantStatus
def _registry():
return get_authenticator().registry
def _extract_bearer(req) -> str | None:
auth = req.headers.get("Authorization")
if not auth or not auth.lower().startswith("bearer "):
return None
return auth.split(None, 1)[1].strip() or None
def _hash_token(token: str) -> str:
return sha256_hex(token)
class BearerCheck:
"""Resolve bearer → populate identity fields."""
"""Resolve bearer → populate identity fields. Rate-limit is enforced
inside `BearerAuthenticator.authenticate`, so no separate step here."""
def __call__(self, ctx: Context) -> None:
token = _extract_bearer(ctx.request)
if not token:
raise Unauthorized("bearer required")
kind = _registry().find(token)
if kind is None:
raise Unauthorized("invalid bearer prefix")
try:
row = kind.resolver.resolve(_hash_token(token))
except TokenExpiredError:
raise Unauthorized("token expired")
if row is None:
raise Unauthorized("invalid bearer")
authn = get_authenticator().authenticate(token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
ctx.subject_type = kind.subject_type
ctx.subject_email = row.subject_email
ctx.subject_issuer = row.subject_issuer
ctx.account_id = row.account_id
ctx.scopes = kind.scopes
ctx.source = kind.source
ctx.token_id = row.token_id
ctx.expires_at = row.expires_at
ctx.subject_type = authn.subject_type
ctx.subject_email = authn.subject_email
ctx.subject_issuer = authn.subject_issuer
ctx.account_id = authn.account_id
ctx.scopes = frozenset(authn.scopes)
ctx.source = authn.source
ctx.token_id = authn.token_id
ctx.expires_at = authn.expires_at
ctx.token_hash = authn.token_hash
ctx.cached_verified_tenants = dict(authn.verified_tenants)
class ScopeCheck:
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
def __call__(self, ctx: Context) -> None:
if "full" in ctx.scopes or ctx.required_scope in ctx.scopes:
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
return
raise Forbidden("insufficient_scope")
@ -73,8 +63,9 @@ class ScopeCheck:
class AppResolver:
"""Read app_id from request.view_args, populate ctx.app + ctx.tenant.
Every endpoint using APP_PIPELINE must declare ``<string:app_id>`` in
its route that is the design lock-in (no body / header coupling).
Every endpoint using the OAuth bearer pipeline must declare
``<string:app_id>`` in its route that is the design lock-in (no body /
header coupling).
"""
def __call__(self, ctx: Context) -> None:
@ -92,6 +83,31 @@ class AppResolver:
ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
"""Layer 0 — workspace membership gate.
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
(dfoa_) only SSO subjects skip.
"""
def __call__(self, ctx: Context) -> None:
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT:
return
if ctx.account_id is None or ctx.tenant is None:
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
if ctx.token_hash is None:
raise Unauthorized("token_hash unset — BearerCheck did not run")
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=ctx.tenant.id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.cached_verified_tenants or {},
)
class AppAuthzCheck:
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
self._resolve = resolve_strategy

View File

@ -7,6 +7,7 @@ composition stays a flat list.
from __future__ import annotations
import uuid
from typing import Protocol
from flask import current_app
@ -38,7 +39,7 @@ class AclStrategy:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=ctx.subject_email,
app_id=ctx.app.id, # type: ignore[attr-defined]
app_id=ctx.app.id,
)
@ -55,10 +56,10 @@ class MembershipStrategy:
return False
if ctx.tenant is None:
return False
return _has_tenant_membership(ctx.account_id, ctx.tenant.id) # type: ignore[attr-defined]
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool:
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
if not account_id:
return False
row = db.session.execute(
@ -92,7 +93,7 @@ class AccountMounter:
account = db.session.get(Account, ctx.account_id)
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.tenant # type: ignore[assignment]
account.current_tenant = ctx.tenant
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
@ -106,8 +107,8 @@ class EndUserMounter:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id, # type: ignore[attr-defined]
app_id=ctx.app.id, # type: ignore[attr-defined]
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
user_id=ctx.subject_email,
)
_login_as(end_user)

View File

@ -3,7 +3,7 @@ service_api/app/completion.py:ChatApi.
Differences from service_api:
- App is in URL path, not header.
- One decorator: @APP_PIPELINE.guard(scope="apps:run").
- One decorator: @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN).
- Request body has no `user` field (Model 2: identity is the bearer).
- Typed Request and Response models.
- invoke_from = InvokeFrom.OPENAPI.
@ -25,7 +25,7 @@ import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import MessageMetadata
from controllers.openapi.auth.composition import APP_PIPELINE
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -45,6 +45,7 @@ from core.errors.error import (
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
@ -101,7 +102,7 @@ def _unpack_caller(caller):
@openapi_ns.route("/apps/<string:app_id>/chat-messages")
class ChatMessagesApi(Resource):
@APP_PIPELINE.guard(scope="apps:run")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
app = _unpack_app(app_model)
if AppMode.value_of(app.mode) not in {

View File

@ -16,7 +16,7 @@ import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import MessageMetadata
from controllers.openapi.auth.composition import APP_PIPELINE
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -33,6 +33,7 @@ from core.errors.error import (
)
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
@ -67,7 +68,7 @@ def _unpack_caller(caller):
@openapi_ns.route("/apps/<string:app_id>/completion-messages")
class CompletionMessagesApi(Resource):
@APP_PIPELINE.guard(scope="apps:run")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
app = _unpack_app(app_model)
if AppMode.value_of(app.mode) != AppMode.COMPLETION:

View File

@ -15,7 +15,7 @@ from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi.auth.composition import APP_PIPELINE
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
CompletionRequestError,
NotWorkflowAppError,
@ -32,6 +32,7 @@ from core.errors.error import (
)
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
@ -77,7 +78,7 @@ def _unpack_caller(caller):
@openapi_ns.route("/apps/<string:app_id>/workflows/run")
class WorkflowRunApi(Resource):
@APP_PIPELINE.guard(scope="apps:run")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
app = _unpack_app(app_model)
if AppMode.value_of(app.mode) != AppMode.WORKFLOW:

View File

@ -542,3 +542,18 @@ class RateLimiter:
self._redis_client.zadd(key, {member: current_time})
self._redis_client.expire(key, self.time_window * 2)
def seconds_until_available(self, email: str) -> int:
"""Seconds until the oldest in-window entry expires, freeing a slot.
Defensive floor of 1 second. Caller should only invoke this after
is_rate_limited() returned True.
"""
key = self._get_key(email)
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
if not oldest:
return 1
_member, score = oldest[0]
free_at = int(score) + self.time_window
remaining = free_at - int(time.time())
return max(remaining, 1)

View File

@ -12,19 +12,22 @@ import json
import logging
import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Literal, ParamSpec, Protocol, TypeVar
from flask import g, request
from sqlalchemy import update
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
from configs import dify_config
from models import OAuthAccessToken
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.rate_limit import enforce_bearer_rate_limit
from models import Account, OAuthAccessToken, TenantAccountJoin
logger = logging.getLogger(__name__)
@ -39,20 +42,54 @@ class SubjectType(StrEnum):
EXTERNAL_SSO = "external_sso"
class Scope(StrEnum):
"""Catalog of bearer scopes recognised by the openapi surface.
`FULL` is the catch-all carried by `dfoa_` account tokens it satisfies
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
(`APPS_RUN` today).
"""
FULL = "full"
APPS_READ = "apps:read"
APPS_RUN = "apps:run"
class Accepts(StrEnum):
"""Subject types a route is willing to accept as caller."""
USER_ACCOUNT = "user_account"
USER_EXT_SSO = "user_ext_sso"
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
@dataclass(frozen=True, slots=True)
class AuthContext:
"""Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
come from the TokenKind, not the DB corrupt rows can't elevate scope.
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
authenticate time. Per-request mutations write through to Redis via
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
"""
subject_type: SubjectType
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
scopes: frozenset[str]
scopes: frozenset[Scope]
token_id: uuid.UUID
source: str
expires_at: datetime | None
token_hash: str
verified_tenants: dict[str, bool] = field(default_factory=dict)
@dataclass(frozen=True, slots=True)
@ -62,6 +99,47 @@ class ResolvedRow:
account_id: uuid.UUID | None
token_id: uuid.UUID
expires_at: datetime | None
verified_tenants: dict[str, bool] = field(default_factory=dict)
def to_cache(self) -> dict:
return {
"subject_email": self.subject_email,
"subject_issuer": self.subject_issuer,
"account_id": str(self.account_id) if self.account_id else None,
"token_id": str(self.token_id),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"verified_tenants": dict(self.verified_tenants),
}
@classmethod
def from_cache(cls, data: dict) -> ResolvedRow:
return cls(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
)
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
on all live deployments (60s TTL safe to drop one release after rollout).
"""
if not isinstance(raw, dict):
return {}
out: dict[str, bool] = {}
for k, v in raw.items():
if isinstance(v, bool):
out[k] = v
elif v == "ok":
out[k] = True
elif v == "denied":
out[k] = False
return out
class Resolver(Protocol):
@ -73,7 +151,7 @@ class Resolver(Protocol):
class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[str]
scopes: frozenset[Scope]
source: str
resolver: Resolver
@ -129,12 +207,20 @@ class BearerAuthenticator:
return self._registry
def authenticate(self, token: str) -> AuthContext:
"""Identity + per-token rate limit (single source).
Both the openapi pipeline (`BearerCheck`) and the decorator
(`validate_bearer`) call this rate-limit fires exactly once per
request regardless of which path hosts the route.
"""
kind = self._registry.find(token)
if kind is None:
raise InvalidBearerError("unknown token prefix")
row = kind.resolver.resolve(sha256_hex(token))
token_hash = sha256_hex(token)
row = kind.resolver.resolve(token_hash)
if row is None:
raise InvalidBearerError("token unknown or revoked")
enforce_bearer_rate_limit(token_hash)
return AuthContext(
subject_type=kind.subject_type,
subject_email=row.subject_email,
@ -144,6 +230,8 @@ class BearerAuthenticator:
token_id=row.token_id,
source=kind.source,
expires_at=row.expires_at,
token_hash=token_hash,
verified_tenants=dict(row.verified_tenants),
)
@ -171,8 +259,6 @@ class OAuthAccessTokenResolver:
positive_ttl: int = POSITIVE_TTL_SECONDS,
negative_ttl: int = NEGATIVE_TTL_SECONDS,
) -> None:
# session_factory and the cache helpers below are friend-API for
# _VariantResolver in this module — kept public-named on purpose.
self.session_factory = session_factory
self._redis = redis_client
self._positive_ttl = positive_ttl
@ -195,8 +281,7 @@ class OAuthAccessTokenResolver:
if text == "invalid":
return "invalid"
try:
data = json.loads(text)
return _row_from_cache(data)
return ResolvedRow.from_cache(json.loads(text))
except (ValueError, KeyError):
logger.warning("auth:token cache entry malformed; treating as miss")
return None
@ -205,7 +290,7 @@ class OAuthAccessTokenResolver:
self._redis.setex(
self._cache_key(token_hash),
self._positive_ttl,
json.dumps(_row_to_cache(row)),
json.dumps(row.to_cache()),
)
def cache_set_negative(self, token_hash: str) -> None:
@ -213,7 +298,7 @@ class OAuthAccessTokenResolver:
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
"""Atomic CAS — only the worker that flips revoked_at emits audit;
replays are idempotent. Spec: tokens.md §Detection + hard-expire.
replays are idempotent.
"""
stmt = (
update(OAuthAccessToken)
@ -247,8 +332,8 @@ class _VariantResolver:
return None
return cached
# session_factory returns Flask-SQLAlchemy's scoped_session, which is
# request-bound and not a context manager; use it directly.
# Flask-SQLAlchemy's scoped_session is request-bound and not a
# context manager; use it directly.
session = self._parent.session_factory()
row = self._load_from_db(session, token_hash)
if row is None:
@ -301,23 +386,87 @@ class _VariantResolver:
)
def _row_to_cache(row: ResolvedRow) -> dict:
return {
"subject_email": row.subject_email,
"subject_issuer": row.subject_issuer,
"account_id": str(row.account_id) if row.account_id else None,
"token_id": str(row.token_id),
"expires_at": row.expires_at.isoformat() if row.expires_at else None,
}
# ============================================================================
# Layer 0 — workspace membership cache + helper
# ============================================================================
def _row_from_cache(data: dict) -> ResolvedRow:
return ResolvedRow(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
"""Merge a Layer-0 membership verdict into the AuthContext cache entry at
`auth:token:{hash}`. No-op if entry missing/expired/invalid next request
rebuilds via authenticate() and re-runs Layer 0.
"""
cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
raw = redis_client.get(cache_key)
if raw is None:
return
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
if text == "invalid":
return
try:
data = json.loads(text)
except (ValueError, KeyError):
return
ttl = redis_client.ttl(cache_key)
if ttl <= 0:
return
data.setdefault("verified_tenants", {})[tenant_id] = verdict
redis_client.setex(cache_key, ttl, json.dumps(data))
def check_workspace_membership(
*,
account_id: uuid.UUID | str,
tenant_id: str,
token_hash: str,
cached_verdicts: dict[str, bool],
) -> None:
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
inline helper (`require_workspace_member`). Caller is responsible for
short-circuiting on EE / SSO subjects before invoking this function
runs the membership + active-status checks unconditionally.
"""
cached = cached_verdicts.get(tenant_id)
if cached is True:
return
if cached is False:
raise Forbidden("workspace_membership_revoked")
join = db.session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.account_id == account_id,
TenantAccountJoin.tenant_id == tenant_id,
)
).scalar_one_or_none()
if join is None:
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
if status != "active":
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
record_layer0_verdict(token_hash, tenant_id, True)
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
"""AuthContext-flavoured wrapper around `check_workspace_membership`.
No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
(no `tenant_account_joins` row by definition).
"""
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
return
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=tenant_id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.verified_tenants,
)
@ -326,20 +475,6 @@ def _row_from_cache(data: dict) -> ResolvedRow:
# ============================================================================
class Accepts(StrEnum):
USER_ACCOUNT = "user_account"
USER_EXT_SSO = "user_ext_sso"
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
_authenticator: BearerAuthenticator | None = None
@ -414,17 +549,10 @@ def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
return inner
# "full" is the catch-all scope carried by dfoa_ tokens; any scope check
# passes when the bearer holds it. dfoe_ ships with apps:run and a few
# narrower scopes; PATs (future) carry only what the user requested at
# mint time.
SCOPE_FULL = "full"
def require_scope(scope: str) -> Callable:
def require_scope(scope: Scope) -> Callable:
"""Route-level scope gate — must run AFTER validate_bearer so that
g.auth_ctx is set. Raises Forbidden('insufficient_scope: <scope>')
when the bearer lacks both the requested scope and the catch-all.
when the bearer lacks both the requested scope and `Scope.FULL`.
"""
def wrap(fn: Callable) -> Callable:
@ -435,7 +563,7 @@ def require_scope(scope: str) -> Callable:
raise RuntimeError(
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
)
if SCOPE_FULL not in ctx.scopes and scope not in ctx.scopes:
if Scope.FULL not in ctx.scopes and scope not in ctx.scopes:
raise Forbidden(f"insufficient_scope: {scope}")
return fn(*args, **kwargs)
@ -456,14 +584,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry:
TokenKind(
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({"full"}),
scopes=frozenset({Scope.FULL}),
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix="dfoe_",
subject_type=SubjectType.EXTERNAL_SSO,
scopes=frozenset({"apps:run"}),
scopes=frozenset({Scope.APPS_RUN}),
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),

View File

@ -2,7 +2,7 @@
window Redis ZSET). Apply after auth decorators so scopes can read
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
in-handler. RFC-8628 ``slow_down`` is inline its response shape isn't
generic 429. Spec: docs/specs/v1.0/server/security.md.
generic 429.
"""
from __future__ import annotations
@ -14,9 +14,10 @@ from enum import StrEnum
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import g, request, session
from flask import g, jsonify, make_response, request, session
from werkzeug.exceptions import TooManyRequests
from configs import dify_config
from libs.helper import RateLimiter, extract_remote_ip
@ -42,6 +43,11 @@ LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSIO
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_BEARER_PER_TOKEN = RateLimit(
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
window=timedelta(minutes=1),
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
)
def _one_key(scope: RateLimitScope) -> str:
@ -113,3 +119,22 @@ def enforce(spec: RateLimit, *, key: str) -> None:
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)
def enforce_bearer_rate_limit(token_hash: str) -> None:
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
Bucket key = ``token:<sha256_hex>`` so the same token shares one
bucket across api replicas (Redis-backed sliding window).
"""
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
key = f"token:{token_hash}"
if limiter.is_rate_limited(key):
retry_after = limiter.seconds_until_available(key)
response = make_response(
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
429,
)
response.headers["Retry-After"] = str(retry_after)
raise TooManyRequests(response=response)
limiter.increment_rate_limit(key)

View File

@ -1,6 +1,6 @@
from unittest.mock import patch
from controllers.openapi.auth.composition import APP_PIPELINE, _resolve_app_authz_strategy
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
@ -8,6 +8,7 @@ from controllers.openapi.auth.steps import (
BearerCheck,
CallerMount,
ScopeCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
@ -17,21 +18,25 @@ from controllers.openapi.auth.strategies import (
)
def test_app_pipeline_is_composed():
assert isinstance(APP_PIPELINE, Pipeline)
def test_pipeline_is_composed():
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
def test_app_pipeline_step_order():
steps = APP_PIPELINE._steps
def test_pipeline_step_order():
"""BearerCheck → ScopeCheck → AppResolver → WorkspaceMembershipCheck →
AppAuthzCheck CallerMount. Rate-limit is enforced inside
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
steps = OAUTH_BEARER_PIPELINE._steps
assert isinstance(steps[0], BearerCheck)
assert isinstance(steps[1], ScopeCheck)
assert isinstance(steps[2], AppResolver)
assert isinstance(steps[3], AppAuthzCheck)
assert isinstance(steps[4], CallerMount)
assert isinstance(steps[3], WorkspaceMembershipCheck)
assert isinstance(steps[4], AppAuthzCheck)
assert isinstance(steps[5], CallerMount)
def test_caller_mount_has_both_mounters():
cm = APP_PIPELINE._steps[4]
cm = OAUTH_BEARER_PIPELINE._steps[5]
kinds = {type(m) for m in cm._mounters}
assert AccountMounter in kinds
assert EndUserMounter in kinds

View File

@ -1,6 +1,5 @@
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@ -8,7 +7,7 @@ from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import BearerCheck
from libs.oauth_bearer import ResolvedRow, SubjectType
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType
def _ctx(headers):
@ -22,37 +21,36 @@ def test_bearer_check_rejects_missing_header():
BearerCheck()(_ctx({}))
@patch("controllers.openapi.auth.steps._registry")
def test_bearer_check_rejects_unknown_prefix(reg):
reg.return_value.find.return_value = None
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
with pytest.raises(Unauthorized):
BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
@patch("controllers.openapi.auth.steps._registry")
def test_bearer_check_populates_context(reg):
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_populates_context(get_auth):
tok_id = uuid.uuid4()
fake_resolver = MagicMock()
fake_resolver.resolve.return_value = ResolvedRow(
authn = AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="a@x.com",
subject_issuer=None,
account_id=None,
scopes=frozenset({Scope.FULL}),
token_id=tok_id,
expires_at=datetime.now(UTC),
)
fake_kind = SimpleNamespace(
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({"full"}),
source="oauth-account",
resolver=fake_resolver,
expires_at=datetime.now(UTC),
token_hash="hash-1",
verified_tenants={},
)
reg.return_value.find.return_value = fake_kind
get_auth.return_value.authenticate.return_value = authn
ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
BearerCheck()(ctx)
assert ctx.subject_type == SubjectType.ACCOUNT
assert ctx.subject_email == "a@x.com"
assert ctx.scopes == frozenset({"full"})
assert ctx.scopes == frozenset({Scope.FULL})
assert ctx.source == "oauth-account"
assert ctx.token_id == tok_id
assert ctx.token_hash == "hash-1"

View File

@ -0,0 +1,157 @@
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
c = Context(request=MagicMock(), required_scope="apps:read")
c.subject_type = subject_type
c.account_id = account_id
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
c.cached_verified_tenants = cached_verified_tenants
c.token_hash = token_hash
return c
@pytest.fixture
def step():
return WorkspaceMembershipCheck()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = True
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id=str(uuid.uuid4()),
tenant_id=str(uuid.uuid4()),
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.EXTERNAL_SSO,
account_id=None,
tenant_id=str(uuid.uuid4()),
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={"t1": True},
token_hash="hash-1",
)
step(ctx)
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={"t1": False},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_record.assert_called_once_with("hash-1", "t1", False)
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
]
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_record.assert_called_once_with("hash-1", "t1", False)
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
]
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_record.assert_called_once_with("hash-1", "t1", True)

View File

@ -7,8 +7,9 @@ from controllers.openapi.auth.pipeline import Pipeline
def bypass_pipeline(monkeypatch):
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
Module-level @APP_PIPELINE.guard(...) captures the real APP_PIPELINE at
import time; mocking the module attribute does not undo that. Patching
Pipeline.run on the class is the bypass that actually works.
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
pipeline at import time; mocking the module attribute does not undo
that. Patching Pipeline.run on the class is the bypass that actually
works.
"""
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)

View File

@ -0,0 +1,94 @@
"""Unit tests for record_layer0_verdict — merge L0 verdict into AuthContext cache."""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from libs.oauth_bearer import record_layer0_verdict
@pytest.fixture
def mock_redis():
return MagicMock()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_cache_entry_missing(mock_redis):
mock_redis.get.return_value = None
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_cache_entry_invalid_marker(mock_redis):
mock_redis.get.return_value = b"invalid"
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_json_malformed(mock_redis):
mock_redis.get.return_value = b"not json"
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_no_op_when_ttl_expired(mock_redis):
mock_redis.get.return_value = json.dumps(
{
"subject_email": "e",
"subject_issuer": None,
"account_id": None,
"token_id": "tid",
"expires_at": None,
}
).encode()
mock_redis.ttl.return_value = -1
record_layer0_verdict("h1", "t1", True)
mock_redis.setex.assert_not_called()
@patch("libs.oauth_bearer.redis_client")
def test_merges_new_tenant_verdict(mock_redis):
mock_redis.get.return_value = json.dumps(
{
"subject_email": "e",
"subject_issuer": None,
"account_id": None,
"token_id": "tid",
"expires_at": None,
"verified_tenants": {"t0": True},
}
).encode()
mock_redis.ttl.return_value = 42
record_layer0_verdict("h1", "t1", False)
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args
assert args.args[0] == "auth:token:h1"
assert args.args[1] == 42 # remaining TTL preserved
written = json.loads(args.args[2])
assert written["verified_tenants"] == {"t0": True, "t1": False}
@patch("libs.oauth_bearer.redis_client")
def test_merges_when_field_absent_from_legacy_entry(mock_redis):
"""Backward compat: legacy cache entry without verified_tenants field."""
mock_redis.get.return_value = json.dumps(
{
"subject_email": "e",
"subject_issuer": None,
"account_id": None,
"token_id": "tid",
"expires_at": None,
}
).encode()
mock_redis.ttl.return_value = 42
record_layer0_verdict("h1", "t1", True)
written = json.loads(mock_redis.setex.call_args.args[2])
assert written["verified_tenants"] == {"t1": True}

View File

@ -12,8 +12,8 @@ from flask import Flask, g
from werkzeug.exceptions import Forbidden
from libs.oauth_bearer import (
SCOPE_FULL,
AuthContext,
Scope,
SubjectType,
require_scope,
)
@ -26,7 +26,7 @@ def app() -> Flask:
return app
def _ctx(scopes: frozenset[str]) -> AuthContext:
def _ctx(scopes) -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="user@example.com",
@ -36,6 +36,8 @@ def _ctx(scopes: frozenset[str]) -> AuthContext:
token_id=uuid.uuid4(),
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants={},
)
@ -67,7 +69,7 @@ def test_require_scope_full_passes_any_check(app: Flask):
return "ok"
with app.test_request_context():
g.auth_ctx = _ctx(frozenset({SCOPE_FULL}))
g.auth_ctx = _ctx(frozenset({Scope.FULL}))
assert view() == "ok"

View File

@ -0,0 +1,74 @@
"""Unit tests for the per-token bearer rate limit primitive."""
from __future__ import annotations
from datetime import timedelta
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import TooManyRequests
from libs.helper import RateLimiter
from libs.rate_limit import (
LIMIT_BEARER_PER_TOKEN,
enforce_bearer_rate_limit,
)
@pytest.fixture
def mock_redis():
return MagicMock()
def test_limit_bearer_per_token_uses_60_per_minute_default():
assert LIMIT_BEARER_PER_TOKEN.limit == 60
assert LIMIT_BEARER_PER_TOKEN.window == timedelta(minutes=1)
def test_seconds_until_available_returns_remaining_window(mock_redis):
"""ZSET oldest entry score = 100; window = 60s; now = 130s → remaining = 30s."""
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
mock_redis.zrange.return_value = [(b"member-1", 100.0)]
with patch("libs.helper.time.time", return_value=130):
assert rl.seconds_until_available("k1") == 30
def test_seconds_until_available_floor_one_second(mock_redis):
"""Even when math says <1s remaining, return at least 1 so client backs off measurably."""
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
mock_redis.zrange.return_value = [(b"member-1", 119.5)]
with patch("libs.helper.time.time", return_value=180):
# window expired (180 > 119.5+60=179.5 by 0.5s) — bucket is actually free now
# but this method only called when is_rate_limited() == True; defensive floor.
assert rl.seconds_until_available("k1") >= 1
def test_seconds_until_available_empty_bucket(mock_redis):
"""No entries → 1s sentinel (defensive; should not be reached when limited)."""
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
mock_redis.zrange.return_value = []
assert rl.seconds_until_available("k1") == 1
@patch("libs.rate_limit._build_limiter")
def test_enforce_bearer_rate_limit_passes_under_limit(mock_build):
limiter = MagicMock()
limiter.is_rate_limited.return_value = False
mock_build.return_value = limiter
enforce_bearer_rate_limit("hash-1")
limiter.increment_rate_limit.assert_called_once_with("token:hash-1")
@patch("libs.rate_limit._build_limiter")
def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build):
limiter = MagicMock()
limiter.is_rate_limited.return_value = True
limiter.seconds_until_available.return_value = 23
mock_build.return_value = limiter
with pytest.raises(TooManyRequests) as exc:
enforce_bearer_rate_limit("hash-1")
headers = dict(exc.value.get_response().headers)
assert headers.get("Retry-After") == "23"
body = exc.value.get_response().get_json() or {}
assert body.get("error") == "rate_limited"
assert body.get("retry_after_ms") == 23000

View File

@ -0,0 +1,93 @@
"""Unit tests for require_workspace_member."""
from __future__ import annotations
import uuid
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT if account else SubjectType.EXTERNAL_SSO,
subject_email="e@example.com",
subject_issuer=None,
account_id=uuid.uuid4() if account else None,
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
source="oauth_account",
expires_at=None,
token_hash="h1",
verified_tenants=dict(verified or {}),
)
@patch("libs.oauth_bearer.dify_config")
def test_skips_when_enterprise_enabled(mock_cfg):
mock_cfg.ENTERPRISE_ENABLED = True
require_workspace_member(_ctx(), "t1")
@patch("libs.oauth_bearer.dify_config")
def test_skips_for_external_sso(mock_cfg):
mock_cfg.ENTERPRISE_ENABLED = False
require_workspace_member(_ctx(account=False), "t1")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_uses_cached_ok_no_db_access(mock_cfg, mock_db):
mock_cfg.ENTERPRISE_ENABLED = False
require_workspace_member(_ctx({"t1": True}), "t1")
mock_db.session.execute.assert_not_called()
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_uses_cached_denied(mock_cfg, mock_db):
mock_cfg.ENTERPRISE_ENABLED = False
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
require_workspace_member(_ctx({"t1": False}), "t1")
mock_db.session.execute.assert_not_called()
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_denies_when_no_membership(mock_cfg, mock_db, mock_record):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
require_workspace_member(_ctx({}), "t1")
mock_record.assert_called_once_with("h1", "t1", False)
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_denies_when_account_inactive(mock_cfg, mock_db, mock_record):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
]
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
require_workspace_member(_ctx({}), "t1")
mock_record.assert_called_once_with("h1", "t1", False)
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
@patch("libs.oauth_bearer.dify_config")
def test_allows_active_member(mock_cfg, mock_db, mock_record):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
]
require_workspace_member(_ctx({}), "t1")
mock_record.assert_called_once_with("h1", "t1", True)