"""
|
|
API Token Service
|
|
|
|
Handles all API token caching, validation, and usage recording.
|
|
Includes Redis cache operations, database queries, and single-flight concurrency control.
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
from werkzeug.exceptions import Unauthorized
|
|
|
|
from extensions.ext_database import db
|
|
from extensions.ext_redis import redis_client, redis_fallback
|
|
from libs.datetime_utils import naive_utc_now
|
|
from models.model import ApiToken
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------
# Pydantic DTO
# ---------------------------------------------------------------------


class CachedApiToken(BaseModel):
    """
    Pydantic model for cached API token data.

    This is NOT a SQLAlchemy model instance, but a plain Pydantic model
    that mimics the ApiToken model interface for read-only access.
    """

    id: str
    app_id: str | None
    tenant_id: str | None
    type: str
    token: str
    last_used_at: datetime | None
    created_at: datetime | None

    def __repr__(self) -> str:
        return f"<CachedApiToken id={self.id} type={self.type}>"

# ---------------------------------------------------------------------
# Cache configuration
# ---------------------------------------------------------------------

CACHE_KEY_PREFIX = "api_token"
CACHE_TTL_SECONDS = 600  # 10 minutes
CACHE_NULL_TTL_SECONDS = 60  # 1 minute for non-existent tokens
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
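
# For reference, the key shapes these settings produce (token "app-xxxx",
# scope "app", and tenant "t-1" are placeholder values):
#
#   cache entry:   api_token:app:app-xxxx        (TTL 600s, or 60s for "null")
#   usage marker:  api_token_active:app:app-xxxx (TTL 3600s, see record_token_usage)
#   tenant index:  tenant_tokens:t-1             (TTL 660s, a set of cache keys)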

# ---------------------------------------------------------------------
# Cache class
# ---------------------------------------------------------------------


class ApiTokenCache:
    """
    Redis cache wrapper for API tokens.

    Handles serialization, deserialization, and cache invalidation.
    """

    @staticmethod
    def make_active_key(token: str, scope: str | None = None) -> str:
        """Generate Redis key for recording token usage."""
        return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"

    @staticmethod
    def _make_tenant_index_key(tenant_id: str) -> str:
        """Generate Redis key for tenant token index."""
        return f"tenant_tokens:{tenant_id}"

    @staticmethod
    def _make_cache_key(token: str, scope: str | None = None) -> str:
        """Generate cache key for the given token and scope."""
        scope_str = scope or "any"
        return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"

    @staticmethod
    def _serialize_token(api_token: Any) -> bytes:
        """Serialize ApiToken object to JSON bytes."""
        if isinstance(api_token, CachedApiToken):
            return api_token.model_dump_json().encode("utf-8")

        cached = CachedApiToken(
            id=str(api_token.id),
            app_id=str(api_token.app_id) if api_token.app_id else None,
            tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
            type=api_token.type,
            token=api_token.token,
            last_used_at=api_token.last_used_at,
            created_at=api_token.created_at,
        )
        return cached.model_dump_json().encode("utf-8")

    @staticmethod
    def _deserialize_token(cached_data: bytes | str) -> Any:
        """Deserialize JSON bytes/string back to a CachedApiToken Pydantic model."""
        if cached_data in {b"null", "null"}:
            return None

        try:
            if isinstance(cached_data, bytes):
                cached_data = cached_data.decode("utf-8")
            return CachedApiToken.model_validate_json(cached_data)
        except Exception as e:
            logger.warning("Failed to deserialize token from cache: %s", e)
            return None

    @staticmethod
    @redis_fallback(default_return=None)
    def get(token: str, scope: str | None) -> Any | None:
        """Get API token from cache."""
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        cached_data = redis_client.get(cache_key)

        if cached_data is None:
            logger.debug("Cache miss for token key: %s", cache_key)
            return None

        logger.debug("Cache hit for token key: %s", cache_key)
        return ApiTokenCache._deserialize_token(cached_data)

    @staticmethod
    def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Add cache key to tenant index for efficient invalidation."""
        if not tenant_id:
            return

        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.sadd(index_key, cache_key)
            redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
        except Exception as e:
            logger.warning("Failed to update tenant index: %s", e)

    @staticmethod
    def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Remove cache key from tenant index."""
        if not tenant_id:
            return

        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.srem(index_key, cache_key)
        except Exception as e:
            logger.warning("Failed to remove from tenant index: %s", e)

    @staticmethod
    @redis_fallback(default_return=False)
    def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
        """Set API token in cache."""
        cache_key = ApiTokenCache._make_cache_key(token, scope)

        if api_token is None:
            cached_value = b"null"
            ttl = CACHE_NULL_TTL_SECONDS
        else:
            cached_value = ApiTokenCache._serialize_token(api_token)

        try:
            redis_client.setex(cache_key, ttl, cached_value)

            if api_token is not None and hasattr(api_token, "tenant_id"):
                ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)

            logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
            return True
        except Exception as e:
            logger.warning("Failed to cache token: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def delete(token: str, scope: str | None = None) -> bool:
        """Delete API token from cache."""
        if scope is None:
            pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
            try:
                keys_to_delete = list(redis_client.scan_iter(match=pattern))
                if keys_to_delete:
                    redis_client.delete(*keys_to_delete)
                    logger.info("Deleted %d cache entries for token", len(keys_to_delete))
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache with pattern: %s", e)
                return False
        else:
            cache_key = ApiTokenCache._make_cache_key(token, scope)
            try:
                tenant_id = None
                try:
                    cached_data = redis_client.get(cache_key)
                    if cached_data and cached_data != b"null":
                        cached_token = ApiTokenCache._deserialize_token(cached_data)
                        if cached_token:
                            tenant_id = cached_token.tenant_id
                except Exception as e:
                    logger.debug("Failed to get tenant_id for cache cleanup: %s", e)

                redis_client.delete(cache_key)

                if tenant_id:
                    ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)

                logger.info("Deleted cache for key: %s", cache_key)
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache: %s", e)
                return False

    @staticmethod
    @redis_fallback(default_return=False)
    def invalidate_by_tenant(tenant_id: str) -> bool:
        """Invalidate all API token caches for a specific tenant via tenant index."""
        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            cache_keys = redis_client.smembers(index_key)

            if cache_keys:
                deleted_count = 0
                for cache_key in cache_keys:
                    if isinstance(cache_key, bytes):
                        cache_key = cache_key.decode("utf-8")
                    redis_client.delete(cache_key)
                    deleted_count += 1

                redis_client.delete(index_key)

                logger.info(
                    "Invalidated %d token cache entries for tenant: %s",
                    deleted_count,
                    tenant_id,
                )
            else:
                logger.info(
                    "No tenant index found for %s, relying on TTL expiration",
                    tenant_id,
                )

            return True

        except Exception as e:
            logger.warning("Failed to invalidate tenant token cache: %s", e)
            return False
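
# Illustrative only: typical cache interactions, assuming a token string
# "app-xxxx" with scope "app" and tenant "tenant-1" (placeholder values):
#
#     cached = ApiTokenCache.get("app-xxxx", "app")      # CachedApiToken | None
#     ApiTokenCache.set("app-xxxx", "app", api_token)    # cache a DB row
#     ApiTokenCache.set("app-xxxx", "app", None)         # negative-cache a miss
#     ApiTokenCache.delete("app-xxxx")                   # drop across all scopes
#     ApiTokenCache.invalidate_by_tenant("tenant-1")     # drop a tenant's entries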

# ---------------------------------------------------------------------
# Token usage recording (for batch update)
# ---------------------------------------------------------------------

def record_token_usage(auth_token: str, scope: str | None) -> None:
    """
    Record token usage in Redis for later batch update by a scheduled job.

    Instead of dispatching a Celery task per request, we simply SET a key in Redis.
    A Celery Beat scheduled task will periodically scan these keys and batch-update
    last_used_at in the database.
    """
    try:
        key = ApiTokenCache.make_active_key(auth_token, scope)
        redis_client.set(key, naive_utc_now().isoformat(), ex=3600)
    except Exception as e:
        logger.warning("Failed to record token usage: %s", e)

# ---------------------------------------------------------------------
# Database query + single-flight
# ---------------------------------------------------------------------

def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
    """
    Query API token from database and cache the result.

    Raises Unauthorized if token is invalid.
    """
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
        api_token = session.scalar(stmt)

        if not api_token:
            ApiTokenCache.set(auth_token, scope, None)
            raise Unauthorized("Access token is invalid")

        ApiTokenCache.set(auth_token, scope, api_token)
        record_token_usage(auth_token, scope)
        return api_token

def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
    """
    Fetch token from DB with single-flight pattern using Redis lock.

    Ensures only one concurrent request queries the database for the same token.
    Falls back to direct query if lock acquisition fails.
    """
    logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)

    lock_key = f"api_token_query_lock:{scope}:{auth_token}"
    lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)

    try:
        if lock.acquire(blocking=True):
            try:
                cached_token = ApiTokenCache.get(auth_token, scope)
                if cached_token is not None:
                    logger.debug("Token cached by concurrent request, using cached version")
                    return cached_token

                return query_token_from_db(auth_token, scope)
            finally:
                lock.release()
        else:
            logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
            return query_token_from_db(auth_token, scope)
    except Unauthorized:
        raise
    except Exception as e:
        logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
        return query_token_from_db(auth_token, scope)
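
# Illustrative only: how a caller would typically combine these pieces
# (the wrapper name and signature are assumptions, not part of this module):
#
#     def validate_api_token(auth_token: str, scope: str) -> ApiToken | CachedApiToken:
#         cached = ApiTokenCache.get(auth_token, scope)
#         if cached is not None:
#             record_token_usage(auth_token, scope)
#             return cached
#         # Cache miss: let one request query the DB while others wait.
#         return fetch_token_with_single_flight(auth_token, scope)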