diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index e036003ca3..77cf59ce68 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -908,6 +908,12 @@ class AuthConfig(BaseSettings): default=True, ) + OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field( + description="Per-token rate limit on /openapi/v1/* (requests per minute). " + "Bucket keyed on sha256(token), shared across api replicas via Redis.", + default=60, + ) + class ModerationConfig(BaseSettings): """ diff --git a/api/controllers/openapi/app_info.py b/api/controllers/openapi/app_info.py index b9a805d015..5c9a714fee 100644 --- a/api/controllers/openapi/app_info.py +++ b/api/controllers/openapi/app_info.py @@ -2,36 +2,34 @@ from __future__ import annotations +from flask import g from flask_restx import Resource -from pydantic import BaseModel +from werkzeug.exceptions import NotFound from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import APP_PIPELINE - - -class AppInfoResponse(BaseModel): - id: str - name: str - description: str | None = None - mode: str - author_name: str | None = None - tags: list[str] = [] - - -def _unpack_app(app_model): - return app_model +from controllers.openapi.apps import account_or_404, app_info_payload +from extensions.ext_database import db +from libs.oauth_bearer import ( + ACCEPT_USER_ANY, + Scope, + require_scope, + require_workspace_member, + validate_bearer, +) +from models import App @openapi_ns.route("/apps//info") class AppInfoApi(Resource): - @APP_PIPELINE.guard(scope="apps:run") - def get(self, app_id, app_model, caller, caller_kind): - app = _unpack_app(app_model) - return AppInfoResponse( - id=app.id, - name=app.name, - description=app.description, - mode=app.mode, - author_name=app.author_name, - tags=[t.name for t in app.tags], - ).model_dump(mode="json") + @validate_bearer(accept=ACCEPT_USER_ANY) + @require_scope(Scope.APPS_READ) # type: ignore[reportUntypedFunctionDecorator] + def get(self, app_id: str): + ctx = g.auth_ctx + account_or_404(ctx) + + app = db.session.get(App, app_id) + if not app or app.status != "normal": + raise NotFound("app not found") + + require_workspace_member(ctx, str(app.tenant_id)) + return app_info_payload(app), 200 diff --git a/api/controllers/openapi/auth/__init__.py b/api/controllers/openapi/auth/__init__.py index ef255d2491..17ac5493d0 100644 --- a/api/controllers/openapi/auth/__init__.py +++ b/api/controllers/openapi/auth/__init__.py @@ -1,3 +1,3 @@ -from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE -__all__ = ["APP_PIPELINE"] +__all__ = ["OAUTH_BEARER_PIPELINE"] diff --git a/api/controllers/openapi/auth/composition.py b/api/controllers/openapi/auth/composition.py index a8da919f29..3fc155e9c1 100644 --- a/api/controllers/openapi/auth/composition.py +++ b/api/controllers/openapi/auth/composition.py @@ -1,6 +1,9 @@ -"""APP_PIPELINE — the only auth scheme for openapi app endpoints. +"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints. -Endpoints attach via @APP_PIPELINE.guard(scope=…). No alternative paths. +Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative +paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip +the pipeline and use `validate_bearer + require_scope + require_workspace_member` +inline — they don't need `AppAuthzCheck`/`CallerMount`. """ from __future__ import annotations @@ -12,6 +15,7 @@ from controllers.openapi.auth.steps import ( BearerCheck, CallerMount, ScopeCheck, + WorkspaceMembershipCheck, ) from controllers.openapi.auth.strategies import ( AccountMounter, @@ -29,10 +33,11 @@ def _resolve_app_authz_strategy() -> AppAuthzStrategy: return MembershipStrategy() -APP_PIPELINE = Pipeline( +OAUTH_BEARER_PIPELINE = Pipeline( BearerCheck(), ScopeCheck(), AppResolver(), + WorkspaceMembershipCheck(), AppAuthzCheck(_resolve_app_authz_strategy), CallerMount(AccountMounter(), EndUserMounter()), ) diff --git a/api/controllers/openapi/auth/context.py b/api/controllers/openapi/auth/context.py index df43a5183d..48a6fd6aeb 100644 --- a/api/controllers/openapi/auth/context.py +++ b/api/controllers/openapi/auth/context.py @@ -7,29 +7,35 @@ read populated values via the decorator's kwargs unpacking. from __future__ import annotations +import uuid from dataclasses import dataclass, field from datetime import datetime -from typing import Literal, Protocol +from typing import TYPE_CHECKING, Literal, Protocol from flask import Request -from libs.oauth_bearer import SubjectType +from libs.oauth_bearer import Scope, SubjectType + +if TYPE_CHECKING: + from models import App, Tenant @dataclass class Context: request: Request - required_scope: str + required_scope: Scope subject_type: SubjectType | None = None subject_email: str | None = None subject_issuer: str | None = None - account_id: str | None = None - scopes: frozenset[str] = field(default_factory=frozenset) - token_id: str | None = None + account_id: uuid.UUID | None = None + scopes: frozenset[Scope] = field(default_factory=frozenset) + token_id: uuid.UUID | None = None + token_hash: str | None = None + cached_verified_tenants: dict[str, bool] | None = None source: str | None = None expires_at: datetime | None = None - app: object | None = None - tenant: object | None = None + app: App | None = None + tenant: Tenant | None = None caller: object | None = None caller_kind: Literal["account", "end_user"] | None = None diff --git a/api/controllers/openapi/auth/pipeline.py b/api/controllers/openapi/auth/pipeline.py index b4ca1e793b..1dbcfab9b2 100644 --- a/api/controllers/openapi/auth/pipeline.py +++ b/api/controllers/openapi/auth/pipeline.py @@ -12,6 +12,7 @@ from functools import wraps from flask import request from controllers.openapi.auth.context import Context, Step +from libs.oauth_bearer import Scope class Pipeline: @@ -22,7 +23,7 @@ class Pipeline: for step in self._steps: step(ctx) - def guard(self, *, scope: str): + def guard(self, *, scope: Scope): def decorator(view): @wraps(view) def decorated(*args, **kwargs): diff --git a/api/controllers/openapi/auth/steps.py b/api/controllers/openapi/auth/steps.py index c671fec21f..f23512c485 100644 --- a/api/controllers/openapi/auth/steps.py +++ b/api/controllers/openapi/auth/steps.py @@ -1,7 +1,7 @@ """Pipeline steps. Each is one responsibility. -BearerCheck is the only step that touches the token registry; downstream -steps see only the populated Context. +`BearerCheck` is the only step that touches the token registry; downstream +steps see only the populated `Context`. """ from __future__ import annotations @@ -10,62 +10,52 @@ from collections.abc import Callable from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized +from configs import dify_config from controllers.openapi.auth.context import Context from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter from extensions.ext_database import db -from libs.oauth_bearer import TokenExpiredError, get_authenticator, sha256_hex +from libs.oauth_bearer import ( + InvalidBearerError, + Scope, + SubjectType, + _extract_bearer, # type: ignore[attr-defined] + check_workspace_membership, + get_authenticator, +) from models import App, Tenant, TenantStatus -def _registry(): - return get_authenticator().registry - - -def _extract_bearer(req) -> str | None: - auth = req.headers.get("Authorization") - if not auth or not auth.lower().startswith("bearer "): - return None - return auth.split(None, 1)[1].strip() or None - - -def _hash_token(token: str) -> str: - return sha256_hex(token) - - class BearerCheck: - """Resolve bearer → populate identity fields.""" + """Resolve bearer → populate identity fields. Rate-limit is enforced + inside `BearerAuthenticator.authenticate`, so no separate step here.""" def __call__(self, ctx: Context) -> None: token = _extract_bearer(ctx.request) if not token: raise Unauthorized("bearer required") - kind = _registry().find(token) - if kind is None: - raise Unauthorized("invalid bearer prefix") - try: - row = kind.resolver.resolve(_hash_token(token)) - except TokenExpiredError: - raise Unauthorized("token expired") - if row is None: - raise Unauthorized("invalid bearer") + authn = get_authenticator().authenticate(token) + except InvalidBearerError as e: + raise Unauthorized(str(e)) - ctx.subject_type = kind.subject_type - ctx.subject_email = row.subject_email - ctx.subject_issuer = row.subject_issuer - ctx.account_id = row.account_id - ctx.scopes = kind.scopes - ctx.source = kind.source - ctx.token_id = row.token_id - ctx.expires_at = row.expires_at + ctx.subject_type = authn.subject_type + ctx.subject_email = authn.subject_email + ctx.subject_issuer = authn.subject_issuer + ctx.account_id = authn.account_id + ctx.scopes = frozenset(authn.scopes) + ctx.source = authn.source + ctx.token_id = authn.token_id + ctx.expires_at = authn.expires_at + ctx.token_hash = authn.token_hash + ctx.cached_verified_tenants = dict(authn.verified_tenants) class ScopeCheck: """Verify ctx.scopes (already populated by BearerCheck) covers required.""" def __call__(self, ctx: Context) -> None: - if "full" in ctx.scopes or ctx.required_scope in ctx.scopes: + if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes: return raise Forbidden("insufficient_scope") @@ -73,8 +63,9 @@ class ScopeCheck: class AppResolver: """Read app_id from request.view_args, populate ctx.app + ctx.tenant. - Every endpoint using APP_PIPELINE must declare ```` in - its route — that is the design lock-in (no body / header coupling). + Every endpoint using the OAuth bearer pipeline must declare + ```` in its route — that is the design lock-in (no body / + header coupling). """ def __call__(self, ctx: Context) -> None: @@ -92,6 +83,31 @@ class AppResolver: ctx.app, ctx.tenant = app, tenant +class WorkspaceMembershipCheck: + """Layer 0 — workspace membership gate. + + CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers + (dfoa_) only — SSO subjects skip. + """ + + def __call__(self, ctx: Context) -> None: + if dify_config.ENTERPRISE_ENABLED: + return + if ctx.subject_type != SubjectType.ACCOUNT: + return + if ctx.account_id is None or ctx.tenant is None: + raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run") + if ctx.token_hash is None: + raise Unauthorized("token_hash unset — BearerCheck did not run") + + check_workspace_membership( + account_id=ctx.account_id, + tenant_id=ctx.tenant.id, + token_hash=ctx.token_hash, + cached_verdicts=ctx.cached_verified_tenants or {}, + ) + + class AppAuthzCheck: def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None: self._resolve = resolve_strategy diff --git a/api/controllers/openapi/auth/strategies.py b/api/controllers/openapi/auth/strategies.py index d8c02f7881..1ab4c650ac 100644 --- a/api/controllers/openapi/auth/strategies.py +++ b/api/controllers/openapi/auth/strategies.py @@ -7,6 +7,7 @@ composition stays a flat list. from __future__ import annotations +import uuid from typing import Protocol from flask import current_app @@ -38,7 +39,7 @@ class AclStrategy: return False return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( user_id=ctx.subject_email, - app_id=ctx.app.id, # type: ignore[attr-defined] + app_id=ctx.app.id, ) @@ -55,10 +56,10 @@ class MembershipStrategy: return False if ctx.tenant is None: return False - return _has_tenant_membership(ctx.account_id, ctx.tenant.id) # type: ignore[attr-defined] + return _has_tenant_membership(ctx.account_id, ctx.tenant.id) -def _has_tenant_membership(account_id: str | None, tenant_id: str) -> bool: +def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool: if not account_id: return False row = db.session.execute( @@ -92,7 +93,7 @@ class AccountMounter: account = db.session.get(Account, ctx.account_id) if account is None: raise RuntimeError("AccountMounter: account row missing for resolved bearer") - account.current_tenant = ctx.tenant # type: ignore[assignment] + account.current_tenant = ctx.tenant _login_as(account) ctx.caller, ctx.caller_kind = account, "account" @@ -106,8 +107,8 @@ class EndUserMounter: raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run") end_user = EndUserService.get_or_create_end_user_by_type( InvokeFrom.OPENAPI, - tenant_id=ctx.tenant.id, # type: ignore[attr-defined] - app_id=ctx.app.id, # type: ignore[attr-defined] + tenant_id=ctx.tenant.id, + app_id=ctx.app.id, user_id=ctx.subject_email, ) _login_as(end_user) diff --git a/api/controllers/openapi/chat_messages.py b/api/controllers/openapi/chat_messages.py index f746edc7da..bbf28b57ff 100644 --- a/api/controllers/openapi/chat_messages.py +++ b/api/controllers/openapi/chat_messages.py @@ -3,7 +3,7 @@ service_api/app/completion.py:ChatApi. Differences from service_api: - App is in URL path, not header. -- One decorator: @APP_PIPELINE.guard(scope="apps:run"). +- One decorator: @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN). - Request body has no `user` field (Model 2: identity is the bearer). - Typed Request and Response models. - invoke_from = InvokeFrom.OPENAPI. @@ -25,7 +25,7 @@ import services from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run from controllers.openapi._models import MessageMetadata -from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -45,6 +45,7 @@ from core.errors.error import ( from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty +from libs.oauth_bearer import Scope from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import ( @@ -101,7 +102,7 @@ def _unpack_caller(caller): @openapi_ns.route("/apps//chat-messages") class ChatMessagesApi(Resource): - @APP_PIPELINE.guard(scope="apps:run") + @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) def post(self, app_id: str, app_model: App, caller, caller_kind: str): app = _unpack_app(app_model) if AppMode.value_of(app.mode) not in { diff --git a/api/controllers/openapi/completion_messages.py b/api/controllers/openapi/completion_messages.py index 1180d43113..bd3df450fb 100644 --- a/api/controllers/openapi/completion_messages.py +++ b/api/controllers/openapi/completion_messages.py @@ -16,7 +16,7 @@ import services from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run from controllers.openapi._models import MessageMetadata -from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -33,6 +33,7 @@ from core.errors.error import ( ) from graphon.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.oauth_bearer import Scope from models.model import App, AppMode from services.app_generate_service import AppGenerateService @@ -67,7 +68,7 @@ def _unpack_caller(caller): @openapi_ns.route("/apps//completion-messages") class CompletionMessagesApi(Resource): - @APP_PIPELINE.guard(scope="apps:run") + @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) def post(self, app_id: str, app_model: App, caller, caller_kind: str): app = _unpack_app(app_model) if AppMode.value_of(app.mode) != AppMode.COMPLETION: diff --git a/api/controllers/openapi/workflow_run.py b/api/controllers/openapi/workflow_run.py index c71e9ab529..ce11b1ded2 100644 --- a/api/controllers/openapi/workflow_run.py +++ b/api/controllers/openapi/workflow_run.py @@ -15,7 +15,7 @@ from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run -from controllers.openapi.auth.composition import APP_PIPELINE +from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE from controllers.service_api.app.error import ( CompletionRequestError, NotWorkflowAppError, @@ -32,6 +32,7 @@ from core.errors.error import ( ) from graphon.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.oauth_bearer import Scope from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import ( @@ -77,7 +78,7 @@ def _unpack_caller(caller): @openapi_ns.route("/apps//workflows/run") class WorkflowRunApi(Resource): - @APP_PIPELINE.guard(scope="apps:run") + @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) def post(self, app_id: str, app_model: App, caller, caller_kind: str): app = _unpack_app(app_model) if AppMode.value_of(app.mode) != AppMode.WORKFLOW: diff --git a/api/libs/helper.py b/api/libs/helper.py index ac69a11084..47472c17da 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -542,3 +542,18 @@ class RateLimiter: self._redis_client.zadd(key, {member: current_time}) self._redis_client.expire(key, self.time_window * 2) + + def seconds_until_available(self, email: str) -> int: + """Seconds until the oldest in-window entry expires, freeing a slot. + + Defensive floor of 1 second. Caller should only invoke this after + is_rate_limited() returned True. + """ + key = self._get_key(email) + oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True) + if not oldest: + return 1 + _member, score = oldest[0] + free_at = int(score) + self.time_window + remaining = free_at - int(time.time()) + return max(remaining, 1) diff --git a/api/libs/oauth_bearer.py b/api/libs/oauth_bearer.py index f524c0b0b4..24e7028b2a 100644 --- a/api/libs/oauth_bearer.py +++ b/api/libs/oauth_bearer.py @@ -12,19 +12,22 @@ import json import logging import uuid from collections.abc import Callable, Iterable -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import UTC, datetime from enum import StrEnum from functools import wraps from typing import Literal, ParamSpec, Protocol, TypeVar from flask import g, request -from sqlalchemy import update +from sqlalchemy import select, update from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized from configs import dify_config -from models import OAuthAccessToken +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.rate_limit import enforce_bearer_rate_limit +from models import Account, OAuthAccessToken, TenantAccountJoin logger = logging.getLogger(__name__) @@ -39,20 +42,54 @@ class SubjectType(StrEnum): EXTERNAL_SSO = "external_sso" +class Scope(StrEnum): + """Catalog of bearer scopes recognised by the openapi surface. + + `FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies + any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes + (`APPS_RUN` today). + """ + + FULL = "full" + APPS_READ = "apps:read" + APPS_RUN = "apps:run" + + +class Accepts(StrEnum): + """Subject types a route is willing to accept as caller.""" + + USER_ACCOUNT = "user_account" + USER_EXT_SSO = "user_ext_sso" + + +ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO}) + +_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = { + SubjectType.ACCOUNT: Accepts.USER_ACCOUNT, + SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO, +} + + @dataclass(frozen=True, slots=True) class AuthContext: """Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source`` come from the TokenKind, not the DB — corrupt rows can't elevate scope. + + `verified_tenants` is a snapshot of the Layer-0 verdict cache at + authenticate time. Per-request mutations write through to Redis via + `record_layer0_verdict`; this snapshot is not updated in place (frozen). """ subject_type: SubjectType subject_email: str | None subject_issuer: str | None account_id: uuid.UUID | None - scopes: frozenset[str] + scopes: frozenset[Scope] token_id: uuid.UUID source: str expires_at: datetime | None + token_hash: str + verified_tenants: dict[str, bool] = field(default_factory=dict) @dataclass(frozen=True, slots=True) @@ -62,6 +99,47 @@ class ResolvedRow: account_id: uuid.UUID | None token_id: uuid.UUID expires_at: datetime | None + verified_tenants: dict[str, bool] = field(default_factory=dict) + + def to_cache(self) -> dict: + return { + "subject_email": self.subject_email, + "subject_issuer": self.subject_issuer, + "account_id": str(self.account_id) if self.account_id else None, + "token_id": str(self.token_id), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "verified_tenants": dict(self.verified_tenants), + } + + @classmethod + def from_cache(cls, data: dict) -> ResolvedRow: + return cls( + subject_email=data["subject_email"], + subject_issuer=data["subject_issuer"], + account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None, + token_id=uuid.UUID(data["token_id"]), + expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None, + verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")), + ) + + +def _coerce_verified_tenants(raw: object) -> dict[str, bool]: + """Tolerate legacy entries that stored 'ok'/'denied' string verdicts. + + TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled + on all live deployments (60s TTL → safe to drop one release after rollout). + """ + if not isinstance(raw, dict): + return {} + out: dict[str, bool] = {} + for k, v in raw.items(): + if isinstance(v, bool): + out[k] = v + elif v == "ok": + out[k] = True + elif v == "denied": + out[k] = False + return out class Resolver(Protocol): @@ -73,7 +151,7 @@ class Resolver(Protocol): class TokenKind: prefix: str subject_type: SubjectType - scopes: frozenset[str] + scopes: frozenset[Scope] source: str resolver: Resolver @@ -129,12 +207,20 @@ class BearerAuthenticator: return self._registry def authenticate(self, token: str) -> AuthContext: + """Identity + per-token rate limit (single source). + + Both the openapi pipeline (`BearerCheck`) and the decorator + (`validate_bearer`) call this — rate-limit fires exactly once per + request regardless of which path hosts the route. + """ kind = self._registry.find(token) if kind is None: raise InvalidBearerError("unknown token prefix") - row = kind.resolver.resolve(sha256_hex(token)) + token_hash = sha256_hex(token) + row = kind.resolver.resolve(token_hash) if row is None: raise InvalidBearerError("token unknown or revoked") + enforce_bearer_rate_limit(token_hash) return AuthContext( subject_type=kind.subject_type, subject_email=row.subject_email, @@ -144,6 +230,8 @@ class BearerAuthenticator: token_id=row.token_id, source=kind.source, expires_at=row.expires_at, + token_hash=token_hash, + verified_tenants=dict(row.verified_tenants), ) @@ -171,8 +259,6 @@ class OAuthAccessTokenResolver: positive_ttl: int = POSITIVE_TTL_SECONDS, negative_ttl: int = NEGATIVE_TTL_SECONDS, ) -> None: - # session_factory and the cache helpers below are friend-API for - # _VariantResolver in this module — kept public-named on purpose. self.session_factory = session_factory self._redis = redis_client self._positive_ttl = positive_ttl @@ -195,8 +281,7 @@ class OAuthAccessTokenResolver: if text == "invalid": return "invalid" try: - data = json.loads(text) - return _row_from_cache(data) + return ResolvedRow.from_cache(json.loads(text)) except (ValueError, KeyError): logger.warning("auth:token cache entry malformed; treating as miss") return None @@ -205,7 +290,7 @@ class OAuthAccessTokenResolver: self._redis.setex( self._cache_key(token_hash), self._positive_ttl, - json.dumps(_row_to_cache(row)), + json.dumps(row.to_cache()), ) def cache_set_negative(self, token_hash: str) -> None: @@ -213,7 +298,7 @@ class OAuthAccessTokenResolver: def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None: """Atomic CAS — only the worker that flips revoked_at emits audit; - replays are idempotent. Spec: tokens.md §Detection + hard-expire. + replays are idempotent. """ stmt = ( update(OAuthAccessToken) @@ -247,8 +332,8 @@ class _VariantResolver: return None return cached - # session_factory returns Flask-SQLAlchemy's scoped_session, which is - # request-bound and not a context manager; use it directly. + # Flask-SQLAlchemy's scoped_session is request-bound and not a + # context manager; use it directly. session = self._parent.session_factory() row = self._load_from_db(session, token_hash) if row is None: @@ -301,23 +386,87 @@ class _VariantResolver: ) -def _row_to_cache(row: ResolvedRow) -> dict: - return { - "subject_email": row.subject_email, - "subject_issuer": row.subject_issuer, - "account_id": str(row.account_id) if row.account_id else None, - "token_id": str(row.token_id), - "expires_at": row.expires_at.isoformat() if row.expires_at else None, - } +# ============================================================================ +# Layer 0 — workspace membership cache + helper +# ============================================================================ -def _row_from_cache(data: dict) -> ResolvedRow: - return ResolvedRow( - subject_email=data["subject_email"], - subject_issuer=data["subject_issuer"], - account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None, - token_id=uuid.UUID(data["token_id"]), - expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None, +def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None: + """Merge a Layer-0 membership verdict into the AuthContext cache entry at + `auth:token:{hash}`. No-op if entry missing/expired/invalid — next request + rebuilds via authenticate() and re-runs Layer 0. + """ + cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash) + raw = redis_client.get(cache_key) + if raw is None: + return + text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw + if text == "invalid": + return + try: + data = json.loads(text) + except (ValueError, KeyError): + return + ttl = redis_client.ttl(cache_key) + if ttl <= 0: + return + data.setdefault("verified_tenants", {})[tenant_id] = verdict + redis_client.setex(cache_key, ttl, json.dumps(data)) + + +def check_workspace_membership( + *, + account_id: uuid.UUID | str, + tenant_id: str, + token_hash: str, + cached_verdicts: dict[str, bool], +) -> None: + """Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow. + + Shared by the pipeline step (`WorkspaceMembershipCheck`) and the + inline helper (`require_workspace_member`). Caller is responsible for + short-circuiting on EE / SSO subjects before invoking — this function + runs the membership + active-status checks unconditionally. + """ + cached = cached_verdicts.get(tenant_id) + if cached is True: + return + if cached is False: + raise Forbidden("workspace_membership_revoked") + + join = db.session.execute( + select(TenantAccountJoin.id).where( + TenantAccountJoin.account_id == account_id, + TenantAccountJoin.tenant_id == tenant_id, + ) + ).scalar_one_or_none() + if join is None: + record_layer0_verdict(token_hash, tenant_id, False) + raise Forbidden("workspace_membership_revoked") + + status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none() + if status != "active": + record_layer0_verdict(token_hash, tenant_id, False) + raise Forbidden("workspace_membership_revoked") + + record_layer0_verdict(token_hash, tenant_id, True) + + +def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None: + """AuthContext-flavoured wrapper around `check_workspace_membership`. + + No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects + (no `tenant_account_joins` row by definition). + """ + if dify_config.ENTERPRISE_ENABLED: + return + if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None: + return + check_workspace_membership( + account_id=ctx.account_id, + tenant_id=tenant_id, + token_hash=ctx.token_hash, + cached_verdicts=ctx.verified_tenants, ) @@ -326,20 +475,6 @@ def _row_from_cache(data: dict) -> ResolvedRow: # ============================================================================ -class Accepts(StrEnum): - USER_ACCOUNT = "user_account" - USER_EXT_SSO = "user_ext_sso" - - -ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO}) - - -_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = { - SubjectType.ACCOUNT: Accepts.USER_ACCOUNT, - SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO, -} - - _authenticator: BearerAuthenticator | None = None @@ -414,17 +549,10 @@ def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]: return inner -# "full" is the catch-all scope carried by dfoa_ tokens; any scope check -# passes when the bearer holds it. dfoe_ ships with apps:run and a few -# narrower scopes; PATs (future) carry only what the user requested at -# mint time. -SCOPE_FULL = "full" - - -def require_scope(scope: str) -> Callable: +def require_scope(scope: Scope) -> Callable: """Route-level scope gate — must run AFTER validate_bearer so that g.auth_ctx is set. Raises Forbidden('insufficient_scope: ') - when the bearer lacks both the requested scope and the catch-all. + when the bearer lacks both the requested scope and `Scope.FULL`. """ def wrap(fn: Callable) -> Callable: @@ -435,7 +563,7 @@ def require_scope(scope: str) -> Callable: raise RuntimeError( "require_scope used without validate_bearer; stack @validate_bearer above @require_scope" ) - if SCOPE_FULL not in ctx.scopes and scope not in ctx.scopes: + if Scope.FULL not in ctx.scopes and scope not in ctx.scopes: raise Forbidden(f"insufficient_scope: {scope}") return fn(*args, **kwargs) @@ -456,14 +584,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry: TokenKind( prefix="dfoa_", subject_type=SubjectType.ACCOUNT, - scopes=frozenset({"full"}), + scopes=frozenset({Scope.FULL}), source="oauth_account", resolver=oauth.for_account(), ), TokenKind( prefix="dfoe_", subject_type=SubjectType.EXTERNAL_SSO, - scopes=frozenset({"apps:run"}), + scopes=frozenset({Scope.APPS_RUN}), source="oauth_external_sso", resolver=oauth.for_external_sso(), ), diff --git a/api/libs/rate_limit.py b/api/libs/rate_limit.py index 8f43f1b312..2818898789 100644 --- a/api/libs/rate_limit.py +++ b/api/libs/rate_limit.py @@ -2,7 +2,7 @@ window Redis ZSET). Apply after auth decorators so scopes can read ``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't -generic 429. Spec: docs/specs/v1.0/server/security.md. +generic 429. """ from __future__ import annotations @@ -14,9 +14,10 @@ from enum import StrEnum from functools import wraps from typing import ParamSpec, TypeVar -from flask import g, request, session +from flask import g, jsonify, make_response, request, session from werkzeug.exceptions import TooManyRequests +from configs import dify_config from libs.helper import RateLimiter, extract_remote_ip @@ -42,6 +43,11 @@ LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSIO LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,)) LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,)) LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,)) +LIMIT_BEARER_PER_TOKEN = RateLimit( + limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN, + window=timedelta(minutes=1), + scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token) +) def _one_key(scope: RateLimitScope) -> str: @@ -113,3 +119,22 @@ def enforce(spec: RateLimit, *, key: str) -> None: if limiter.is_rate_limited(key): raise TooManyRequests("rate_limited") limiter.increment_rate_limit(key) + + +def enforce_bearer_rate_limit(token_hash: str) -> None: + """Per-token rate limit on /openapi/v1/* bearer-authed routes. + + Bucket key = ``token:`` so the same token shares one + bucket across api replicas (Redis-backed sliding window). + """ + limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN) + key = f"token:{token_hash}" + if limiter.is_rate_limited(key): + retry_after = limiter.seconds_until_available(key) + response = make_response( + jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}), + 429, + ) + response.headers["Retry-After"] = str(retry_after) + raise TooManyRequests(response=response) + limiter.increment_rate_limit(key) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py index 48fe5fd6aa..f8ff2b540c 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from controllers.openapi.auth.composition import APP_PIPELINE, _resolve_app_authz_strategy +from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy from controllers.openapi.auth.pipeline import Pipeline from controllers.openapi.auth.steps import ( AppAuthzCheck, @@ -8,6 +8,7 @@ from controllers.openapi.auth.steps import ( BearerCheck, CallerMount, ScopeCheck, + WorkspaceMembershipCheck, ) from controllers.openapi.auth.strategies import ( AccountMounter, @@ -17,21 +18,25 @@ from controllers.openapi.auth.strategies import ( ) -def test_app_pipeline_is_composed(): - assert isinstance(APP_PIPELINE, Pipeline) +def test_pipeline_is_composed(): + assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline) -def test_app_pipeline_step_order(): - steps = APP_PIPELINE._steps +def test_pipeline_step_order(): + """BearerCheck → ScopeCheck → AppResolver → WorkspaceMembershipCheck → + AppAuthzCheck → CallerMount. Rate-limit is enforced inside + `BearerAuthenticator.authenticate`, not as a separate pipeline step.""" + steps = OAUTH_BEARER_PIPELINE._steps assert isinstance(steps[0], BearerCheck) assert isinstance(steps[1], ScopeCheck) assert isinstance(steps[2], AppResolver) - assert isinstance(steps[3], AppAuthzCheck) - assert isinstance(steps[4], CallerMount) + assert isinstance(steps[3], WorkspaceMembershipCheck) + assert isinstance(steps[4], AppAuthzCheck) + assert isinstance(steps[5], CallerMount) def test_caller_mount_has_both_mounters(): - cm = APP_PIPELINE._steps[4] + cm = OAUTH_BEARER_PIPELINE._steps[5] kinds = {type(m) for m in cm._mounters} assert AccountMounter in kinds assert EndUserMounter in kinds diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py index f59120686b..49f456cbdd 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py @@ -1,6 +1,5 @@ import uuid from datetime import UTC, datetime -from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -8,7 +7,7 @@ from werkzeug.exceptions import Unauthorized from controllers.openapi.auth.context import Context from controllers.openapi.auth.steps import BearerCheck -from libs.oauth_bearer import ResolvedRow, SubjectType +from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType def _ctx(headers): @@ -22,37 +21,36 @@ def test_bearer_check_rejects_missing_header(): BearerCheck()(_ctx({})) -@patch("controllers.openapi.auth.steps._registry") -def test_bearer_check_rejects_unknown_prefix(reg): - reg.return_value.find.return_value = None +@patch("controllers.openapi.auth.steps.get_authenticator") +def test_bearer_check_rejects_unknown_prefix(get_auth): + get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix") with pytest.raises(Unauthorized): BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"})) -@patch("controllers.openapi.auth.steps._registry") -def test_bearer_check_populates_context(reg): +@patch("controllers.openapi.auth.steps.get_authenticator") +def test_bearer_check_populates_context(get_auth): tok_id = uuid.uuid4() - fake_resolver = MagicMock() - fake_resolver.resolve.return_value = ResolvedRow( + authn = AuthContext( + subject_type=SubjectType.ACCOUNT, subject_email="a@x.com", subject_issuer=None, account_id=None, + scopes=frozenset({Scope.FULL}), token_id=tok_id, - expires_at=datetime.now(UTC), - ) - fake_kind = SimpleNamespace( - subject_type=SubjectType.ACCOUNT, - scopes=frozenset({"full"}), source="oauth-account", - resolver=fake_resolver, + expires_at=datetime.now(UTC), + token_hash="hash-1", + verified_tenants={}, ) - reg.return_value.find.return_value = fake_kind + get_auth.return_value.authenticate.return_value = authn ctx = _ctx({"Authorization": "Bearer dfoa_abc"}) BearerCheck()(ctx) assert ctx.subject_type == SubjectType.ACCOUNT assert ctx.subject_email == "a@x.com" - assert ctx.scopes == frozenset({"full"}) + assert ctx.scopes == frozenset({Scope.FULL}) assert ctx.source == "oauth-account" assert ctx.token_id == tok_id + assert ctx.token_hash == "hash-1" diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py new file mode 100644 index 0000000000..4ae8f90246 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py @@ -0,0 +1,157 @@ +"""Unit tests for WorkspaceMembershipCheck (Layer 0).""" + +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.openapi.auth.context import Context +from controllers.openapi.auth.steps import WorkspaceMembershipCheck +from libs.oauth_bearer import SubjectType + + +def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context: + c = Context(request=MagicMock(), required_scope="apps:read") + c.subject_type = subject_type + c.account_id = account_id + c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None + c.cached_verified_tenants = cached_verified_tenants + c.token_hash = token_hash + return c + + +@pytest.fixture +def step(): + return WorkspaceMembershipCheck() + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = True + ctx = _ctx( + subject_type=SubjectType.ACCOUNT, + account_id=str(uuid.uuid4()), + tenant_id=str(uuid.uuid4()), + cached_verified_tenants={}, + token_hash="hash-1", + ) + step(ctx) # no raise + mock_db.session.execute.assert_not_called() + mock_record.assert_not_called() + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = False + ctx = _ctx( + subject_type=SubjectType.EXTERNAL_SSO, + account_id=None, + tenant_id=str(uuid.uuid4()), + cached_verified_tenants={}, + token_hash="hash-1", + ) + step(ctx) # no raise + mock_db.session.execute.assert_not_called() + mock_record.assert_not_called() + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = False + ctx = _ctx( + subject_type=SubjectType.ACCOUNT, + account_id="a1", + tenant_id="t1", + cached_verified_tenants={"t1": True}, + token_hash="hash-1", + ) + step(ctx) + mock_db.session.execute.assert_not_called() + mock_record.assert_not_called() + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = False + ctx = _ctx( + subject_type=SubjectType.ACCOUNT, + account_id="a1", + tenant_id="t1", + cached_verified_tenants={"t1": False}, + token_hash="hash-1", + ) + with pytest.raises(Forbidden, match="workspace_membership_revoked"): + step(ctx) + mock_db.session.execute.assert_not_called() + mock_record.assert_not_called() + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = False + mock_db.session.execute.return_value.scalar_one_or_none.return_value = None + ctx = _ctx( + subject_type=SubjectType.ACCOUNT, + account_id="a1", + tenant_id="t1", + cached_verified_tenants={}, + token_hash="hash-1", + ) + with pytest.raises(Forbidden, match="workspace_membership_revoked"): + step(ctx) + mock_record.assert_called_once_with("hash-1", "t1", False) + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = False + mock_db.session.execute.side_effect = [ + MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), + MagicMock(scalar_one_or_none=MagicMock(return_value="banned")), + ] + ctx = _ctx( + subject_type=SubjectType.ACCOUNT, + account_id="a1", + tenant_id="t1", + cached_verified_tenants={}, + token_hash="hash-1", + ) + with pytest.raises(Forbidden, match="workspace_membership_revoked"): + step(ctx) + mock_record.assert_called_once_with("hash-1", "t1", False) + + +@patch("controllers.openapi.auth.steps.dify_config") +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +def test_allows_active_member(mock_db, mock_record, mock_cfg, step): + mock_cfg.ENTERPRISE_ENABLED = False + mock_db.session.execute.side_effect = [ + MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), + MagicMock(scalar_one_or_none=MagicMock(return_value="active")), + ] + ctx = _ctx( + subject_type=SubjectType.ACCOUNT, + account_id="a1", + tenant_id="t1", + cached_verified_tenants={}, + token_hash="hash-1", + ) + step(ctx) # no raise + mock_record.assert_called_once_with("hash-1", "t1", True) diff --git a/api/tests/unit_tests/controllers/openapi/conftest.py b/api/tests/unit_tests/controllers/openapi/conftest.py index 42e3768a18..9486ff6e94 100644 --- a/api/tests/unit_tests/controllers/openapi/conftest.py +++ b/api/tests/unit_tests/controllers/openapi/conftest.py @@ -7,8 +7,9 @@ from controllers.openapi.auth.pipeline import Pipeline def bypass_pipeline(monkeypatch): """Stub Pipeline.run so endpoint decoration does not invoke real auth. - Module-level @APP_PIPELINE.guard(...) captures the real APP_PIPELINE at - import time; mocking the module attribute does not undo that. Patching - Pipeline.run on the class is the bypass that actually works. + Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real + pipeline at import time; mocking the module attribute does not undo + that. Patching Pipeline.run on the class is the bypass that actually + works. """ monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None) diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_layer0_cache.py b/api/tests/unit_tests/libs/test_oauth_bearer_layer0_cache.py new file mode 100644 index 0000000000..0023f17119 --- /dev/null +++ b/api/tests/unit_tests/libs/test_oauth_bearer_layer0_cache.py @@ -0,0 +1,94 @@ +"""Unit tests for record_layer0_verdict — merge L0 verdict into AuthContext cache.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from libs.oauth_bearer import record_layer0_verdict + + +@pytest.fixture +def mock_redis(): + return MagicMock() + + +@patch("libs.oauth_bearer.redis_client") +def test_no_op_when_cache_entry_missing(mock_redis): + mock_redis.get.return_value = None + record_layer0_verdict("h1", "t1", True) + mock_redis.setex.assert_not_called() + + +@patch("libs.oauth_bearer.redis_client") +def test_no_op_when_cache_entry_invalid_marker(mock_redis): + mock_redis.get.return_value = b"invalid" + record_layer0_verdict("h1", "t1", True) + mock_redis.setex.assert_not_called() + + +@patch("libs.oauth_bearer.redis_client") +def test_no_op_when_json_malformed(mock_redis): + mock_redis.get.return_value = b"not json" + record_layer0_verdict("h1", "t1", True) + mock_redis.setex.assert_not_called() + + +@patch("libs.oauth_bearer.redis_client") +def test_no_op_when_ttl_expired(mock_redis): + mock_redis.get.return_value = json.dumps( + { + "subject_email": "e", + "subject_issuer": None, + "account_id": None, + "token_id": "tid", + "expires_at": None, + } + ).encode() + mock_redis.ttl.return_value = -1 + record_layer0_verdict("h1", "t1", True) + mock_redis.setex.assert_not_called() + + +@patch("libs.oauth_bearer.redis_client") +def test_merges_new_tenant_verdict(mock_redis): + mock_redis.get.return_value = json.dumps( + { + "subject_email": "e", + "subject_issuer": None, + "account_id": None, + "token_id": "tid", + "expires_at": None, + "verified_tenants": {"t0": True}, + } + ).encode() + mock_redis.ttl.return_value = 42 + + record_layer0_verdict("h1", "t1", False) + + mock_redis.setex.assert_called_once() + args = mock_redis.setex.call_args + assert args.args[0] == "auth:token:h1" + assert args.args[1] == 42 # remaining TTL preserved + written = json.loads(args.args[2]) + assert written["verified_tenants"] == {"t0": True, "t1": False} + + +@patch("libs.oauth_bearer.redis_client") +def test_merges_when_field_absent_from_legacy_entry(mock_redis): + """Backward compat: legacy cache entry without verified_tenants field.""" + mock_redis.get.return_value = json.dumps( + { + "subject_email": "e", + "subject_issuer": None, + "account_id": None, + "token_id": "tid", + "expires_at": None, + } + ).encode() + mock_redis.ttl.return_value = 42 + record_layer0_verdict("h1", "t1", True) + written = json.loads(mock_redis.setex.call_args.args[2]) + assert written["verified_tenants"] == {"t1": True} diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py index e9f26e59ea..dfb8735702 100644 --- a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py +++ b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py @@ -12,8 +12,8 @@ from flask import Flask, g from werkzeug.exceptions import Forbidden from libs.oauth_bearer import ( - SCOPE_FULL, AuthContext, + Scope, SubjectType, require_scope, ) @@ -26,7 +26,7 @@ def app() -> Flask: return app -def _ctx(scopes: frozenset[str]) -> AuthContext: +def _ctx(scopes) -> AuthContext: return AuthContext( subject_type=SubjectType.ACCOUNT, subject_email="user@example.com", @@ -36,6 +36,8 @@ def _ctx(scopes: frozenset[str]) -> AuthContext: token_id=uuid.uuid4(), source="oauth_account", expires_at=None, + token_hash="h1", + verified_tenants={}, ) @@ -67,7 +69,7 @@ def test_require_scope_full_passes_any_check(app: Flask): return "ok" with app.test_request_context(): - g.auth_ctx = _ctx(frozenset({SCOPE_FULL})) + g.auth_ctx = _ctx(frozenset({Scope.FULL})) assert view() == "ok" diff --git a/api/tests/unit_tests/libs/test_rate_limit_bearer.py b/api/tests/unit_tests/libs/test_rate_limit_bearer.py new file mode 100644 index 0000000000..b204575ccb --- /dev/null +++ b/api/tests/unit_tests/libs/test_rate_limit_bearer.py @@ -0,0 +1,74 @@ +"""Unit tests for the per-token bearer rate limit primitive.""" + +from __future__ import annotations + +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import TooManyRequests + +from libs.helper import RateLimiter +from libs.rate_limit import ( + LIMIT_BEARER_PER_TOKEN, + enforce_bearer_rate_limit, +) + + +@pytest.fixture +def mock_redis(): + return MagicMock() + + +def test_limit_bearer_per_token_uses_60_per_minute_default(): + assert LIMIT_BEARER_PER_TOKEN.limit == 60 + assert LIMIT_BEARER_PER_TOKEN.window == timedelta(minutes=1) + + +def test_seconds_until_available_returns_remaining_window(mock_redis): + """ZSET oldest entry score = 100; window = 60s; now = 130s → remaining = 30s.""" + rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis) + mock_redis.zrange.return_value = [(b"member-1", 100.0)] + with patch("libs.helper.time.time", return_value=130): + assert rl.seconds_until_available("k1") == 30 + + +def test_seconds_until_available_floor_one_second(mock_redis): + """Even when math says <1s remaining, return at least 1 so client backs off measurably.""" + rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis) + mock_redis.zrange.return_value = [(b"member-1", 119.5)] + with patch("libs.helper.time.time", return_value=180): + # window expired (180 > 119.5+60=179.5 by 0.5s) — bucket is actually free now + # but this method only called when is_rate_limited() == True; defensive floor. + assert rl.seconds_until_available("k1") >= 1 + + +def test_seconds_until_available_empty_bucket(mock_redis): + """No entries → 1s sentinel (defensive; should not be reached when limited).""" + rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis) + mock_redis.zrange.return_value = [] + assert rl.seconds_until_available("k1") == 1 + + +@patch("libs.rate_limit._build_limiter") +def test_enforce_bearer_rate_limit_passes_under_limit(mock_build): + limiter = MagicMock() + limiter.is_rate_limited.return_value = False + mock_build.return_value = limiter + enforce_bearer_rate_limit("hash-1") + limiter.increment_rate_limit.assert_called_once_with("token:hash-1") + + +@patch("libs.rate_limit._build_limiter") +def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build): + limiter = MagicMock() + limiter.is_rate_limited.return_value = True + limiter.seconds_until_available.return_value = 23 + mock_build.return_value = limiter + with pytest.raises(TooManyRequests) as exc: + enforce_bearer_rate_limit("hash-1") + headers = dict(exc.value.get_response().headers) + assert headers.get("Retry-After") == "23" + body = exc.value.get_response().get_json() or {} + assert body.get("error") == "rate_limited" + assert body.get("retry_after_ms") == 23000 diff --git a/api/tests/unit_tests/libs/test_workspace_member_helper.py b/api/tests/unit_tests/libs/test_workspace_member_helper.py new file mode 100644 index 0000000000..5f24a728a0 --- /dev/null +++ b/api/tests/unit_tests/libs/test_workspace_member_helper.py @@ -0,0 +1,93 @@ +"""Unit tests for require_workspace_member.""" + +from __future__ import annotations + +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member + + +def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext: + return AuthContext( + subject_type=SubjectType.ACCOUNT if account else SubjectType.EXTERNAL_SSO, + subject_email="e@example.com", + subject_issuer=None, + account_id=uuid.uuid4() if account else None, + scopes=frozenset({Scope.FULL}), + token_id=uuid.uuid4(), + source="oauth_account", + expires_at=None, + token_hash="h1", + verified_tenants=dict(verified or {}), + ) + + +@patch("libs.oauth_bearer.dify_config") +def test_skips_when_enterprise_enabled(mock_cfg): + mock_cfg.ENTERPRISE_ENABLED = True + require_workspace_member(_ctx(), "t1") + + +@patch("libs.oauth_bearer.dify_config") +def test_skips_for_external_sso(mock_cfg): + mock_cfg.ENTERPRISE_ENABLED = False + require_workspace_member(_ctx(account=False), "t1") + + +@patch("libs.oauth_bearer.db") +@patch("libs.oauth_bearer.dify_config") +def test_uses_cached_ok_no_db_access(mock_cfg, mock_db): + mock_cfg.ENTERPRISE_ENABLED = False + require_workspace_member(_ctx({"t1": True}), "t1") + mock_db.session.execute.assert_not_called() + + +@patch("libs.oauth_bearer.db") +@patch("libs.oauth_bearer.dify_config") +def test_uses_cached_denied(mock_cfg, mock_db): + mock_cfg.ENTERPRISE_ENABLED = False + with pytest.raises(Forbidden, match="workspace_membership_revoked"): + require_workspace_member(_ctx({"t1": False}), "t1") + mock_db.session.execute.assert_not_called() + + +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +@patch("libs.oauth_bearer.dify_config") +def test_denies_when_no_membership(mock_cfg, mock_db, mock_record): + mock_cfg.ENTERPRISE_ENABLED = False + mock_db.session.execute.return_value.scalar_one_or_none.return_value = None + with pytest.raises(Forbidden, match="workspace_membership_revoked"): + require_workspace_member(_ctx({}), "t1") + mock_record.assert_called_once_with("h1", "t1", False) + + +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +@patch("libs.oauth_bearer.dify_config") +def test_denies_when_account_inactive(mock_cfg, mock_db, mock_record): + mock_cfg.ENTERPRISE_ENABLED = False + mock_db.session.execute.side_effect = [ + MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), + MagicMock(scalar_one_or_none=MagicMock(return_value="banned")), + ] + with pytest.raises(Forbidden, match="workspace_membership_revoked"): + require_workspace_member(_ctx({}), "t1") + mock_record.assert_called_once_with("h1", "t1", False) + + +@patch("libs.oauth_bearer.record_layer0_verdict") +@patch("libs.oauth_bearer.db") +@patch("libs.oauth_bearer.dify_config") +def test_allows_active_member(mock_cfg, mock_db, mock_record): + mock_cfg.ENTERPRISE_ENABLED = False + mock_db.session.execute.side_effect = [ + MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), + MagicMock(scalar_one_or_none=MagicMock(return_value="active")), + ] + require_workspace_member(_ctx({}), "t1") + mock_record.assert_called_once_with("h1", "t1", True)