mirror of
https://github.com/langgenius/dify.git
synced 2026-05-12 15:58:19 +08:00
141 lines
4.8 KiB
Python
141 lines
4.8 KiB
Python
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
|
|
window Redis ZSET). Apply after auth decorators so scopes can read
|
|
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
|
|
in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't
|
|
generic 429.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from datetime import timedelta
|
|
from enum import StrEnum
|
|
from functools import wraps
|
|
from typing import ParamSpec, TypeVar
|
|
|
|
from flask import g, jsonify, make_response, request, session
|
|
from werkzeug.exceptions import TooManyRequests
|
|
|
|
from configs import dify_config
|
|
from libs.helper import RateLimiter, extract_remote_ip
|
|
|
|
|
|
class RateLimitScope(StrEnum):
|
|
IP = "ip"
|
|
SESSION = "session"
|
|
ACCOUNT = "account"
|
|
SUBJECT_EMAIL = "subject_email"
|
|
TOKEN_ID = "token_id"
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class RateLimit:
|
|
limit: int
|
|
window: timedelta
|
|
scopes: tuple[RateLimitScope, ...]
|
|
|
|
|
|
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
|
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
|
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
|
|
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
|
|
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
|
|
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
|
|
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
|
|
LIMIT_BEARER_PER_TOKEN = RateLimit(
|
|
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
|
|
window=timedelta(minutes=1),
|
|
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
|
|
)
|
|
|
|
|
|
def _one_key(scope: RateLimitScope) -> str:
|
|
match scope:
|
|
case RateLimitScope.IP:
|
|
return f"ip:{extract_remote_ip(request) or 'unknown'}"
|
|
case RateLimitScope.SESSION:
|
|
return f"session:{session.get('_id', 'anon')}"
|
|
case RateLimitScope.ACCOUNT:
|
|
ctx = getattr(g, "auth_ctx", None)
|
|
if ctx and ctx.account_id:
|
|
return f"account:{ctx.account_id}"
|
|
return "account:anon"
|
|
case RateLimitScope.SUBJECT_EMAIL:
|
|
ctx = getattr(g, "auth_ctx", None)
|
|
if ctx and ctx.subject_email:
|
|
return f"subject:{ctx.subject_email}"
|
|
return "subject:anon"
|
|
case RateLimitScope.TOKEN_ID:
|
|
ctx = getattr(g, "auth_ctx", None)
|
|
if ctx and ctx.token_id:
|
|
return f"token:{ctx.token_id}"
|
|
return "token:anon"
|
|
|
|
|
|
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
|
|
return "|".join(_one_key(s) for s in scopes)
|
|
|
|
|
|
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
|
|
return "rl:" + "+".join(s.value for s in scopes)
|
|
|
|
|
|
def _build_limiter(spec: RateLimit) -> RateLimiter:
|
|
return RateLimiter(
|
|
prefix=_limiter_prefix(spec.scopes),
|
|
max_attempts=spec.limit,
|
|
time_window=int(spec.window.total_seconds()),
|
|
)
|
|
|
|
|
|
_P = ParamSpec("_P")
|
|
_R = TypeVar("_R")
|
|
|
|
|
|
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
|
"""Apply after auth decorators that the scopes read from."""
|
|
limiter = _build_limiter(spec)
|
|
|
|
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
|
@wraps(fn)
|
|
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
|
key = _composite_key(spec.scopes)
|
|
if limiter.is_rate_limited(key):
|
|
raise TooManyRequests("rate_limited")
|
|
limiter.increment_rate_limit(key)
|
|
return fn(*args, **kwargs)
|
|
|
|
return inner
|
|
|
|
return wrap
|
|
|
|
|
|
def enforce(spec: RateLimit, *, key: str) -> None:
|
|
"""Imperative form — caller composes the bucket key to match scope
|
|
semantics (the key is opaque here).
|
|
"""
|
|
limiter = _build_limiter(spec)
|
|
if limiter.is_rate_limited(key):
|
|
raise TooManyRequests("rate_limited")
|
|
limiter.increment_rate_limit(key)
|
|
|
|
|
|
def enforce_bearer_rate_limit(token_hash: str) -> None:
|
|
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
|
|
|
|
Bucket key = ``token:<sha256_hex>`` so the same token shares one
|
|
bucket across api replicas (Redis-backed sliding window).
|
|
"""
|
|
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
|
|
key = f"token:{token_hash}"
|
|
if limiter.is_rate_limited(key):
|
|
retry_after = limiter.seconds_until_available(key)
|
|
response = make_response(
|
|
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
|
|
429,
|
|
)
|
|
response.headers["Retry-After"] = str(retry_after)
|
|
raise TooManyRequests(response=response)
|
|
limiter.increment_rate_limit(key)
|