From 2c9430313dba35ce948cf342a28f1bed98b93b1e Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Fri, 6 Feb 2026 16:25:27 +0800 Subject: [PATCH] fix: redis for api token (#31861) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: hj24 --- api/README.md | 2 +- api/configs/feature/__init__.py | 10 + api/controllers/console/apikey.py | 6 + api/controllers/console/datasets/datasets.py | 6 + api/controllers/service_api/wraps.py | 49 +-- api/docker/entrypoint.sh | 4 +- api/extensions/ext_celery.py | 8 + .../update_api_token_last_used_task.py | 114 ++++++ api/services/api_token_service.py | 330 +++++++++++++++ api/tasks/remove_app_and_related_data_task.py | 7 + .../libs/test_api_token_cache_integration.py | 375 ++++++++++++++++++ .../unit_tests/extensions/test_celery_ssl.py | 2 + .../unit_tests/libs/test_api_token_cache.py | 250 ++++++++++++ 13 files changed, 1132 insertions(+), 31 deletions(-) create mode 100644 api/schedule/update_api_token_last_used_task.py create mode 100644 api/services/api_token_service.py create mode 100644 api/tests/integration_tests/libs/test_api_token_cache_integration.py create mode 100644 api/tests/unit_tests/libs/test_api_token_cache.py diff --git a/api/README.md b/api/README.md index 9d89b490b0..b23edeab72 100644 --- a/api/README.md +++ b/api/README.md @@ -122,7 +122,7 @@ These commands assume you start from the repository root. ```bash cd api - uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention + uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention ``` 1. Optional: start Celery Beat (scheduled tasks, in a new terminal). diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d97e9a0440..e8c1b522de 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1155,6 +1155,16 @@ class CeleryScheduleTasksConfig(BaseSettings): default=0, ) + # API token last_used_at batch update + ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field( + description="Enable periodic batch update of API token last_used_at timestamps", + default=True, + ) + API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field( + description="Interval in minutes for batch updating API token last_used_at (default 30)", + default=30, + ) + # Trigger provider refresh (simple version) ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field( description="Enable trigger provider refresh poller", diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index c81709e985..b6d1df319e 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -10,6 +10,7 @@ from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.model import ApiToken, App +from services.api_token_service import ApiTokenCache from . 
import console_ns from .wraps import account_initialization_required, edit_permission_required, setup_required @@ -131,6 +132,11 @@ class BaseApiKeyResource(Resource): if key is None: flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") + # Invalidate cache before deleting from database + # Type assertion: key is guaranteed to be non-None here because abort() raises + assert key is not None # nosec - for type checker only + ApiTokenCache.delete(key.token, key.type) + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 30e4ed1119..a06b872846 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -55,6 +55,7 @@ from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID +from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService # Register models for flask_restx to avoid dict type issues in Swagger @@ -820,6 +821,11 @@ class DatasetApiDeleteApi(Resource): if key is None: console_ns.abort(404, message="API key not found") + # Invalidate cache before deleting from database + # Type assertion: key is guaranteed to be non-None here because abort() raises + assert key is not None # nosec - for type checker only + ApiTokenCache.delete(key.token, key.type) + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index e597a72fc0..b80735914d 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,27 +1,24 @@ import logging import time from collections.abc import Callable -from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Concatenate, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in from flask_restx import Resource from pydantic import BaseModel -from sqlalchemy import select, update -from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App +from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage from services.end_user_service import EndUserService from services.feature_service import FeatureService @@ -296,7 +293,14 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): def validate_and_get_api_token(scope: str | None = None): """ - Validate and get API token. + Validate and get API token with Redis caching. + + This function uses a two-tier approach: + 1. First checks Redis cache for the token + 2. 
If not cached, queries database and caches the result + + The last_used_at field is updated asynchronously via Celery task + to avoid blocking the request. """ auth_header = request.headers.get("Authorization") if auth_header is None or " " not in auth_header: @@ -308,29 +312,18 @@ def validate_and_get_api_token(scope: str | None = None): if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - current_time = naive_utc_now() - cutoff_time = current_time - timedelta(minutes=1) - with Session(db.engine, expire_on_commit=False) as session: - update_stmt = ( - update(ApiToken) - .where( - ApiToken.token == auth_token, - (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)), - ApiToken.type == scope, - ) - .values(last_used_at=current_time) - ) - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - result = session.execute(update_stmt) - api_token = session.scalar(stmt) + # Try to get token from cache first + # Returns a CachedApiToken (plain Python object), not a SQLAlchemy model + cached_token = ApiTokenCache.get(auth_token, scope) + if cached_token is not None: + logger.debug("Token validation served from cache for scope: %s", scope) + # Record usage in Redis for later batch update (no Celery task per request) + record_token_usage(auth_token, scope) + return cast(ApiToken, cached_token) - if hasattr(result, "rowcount") and result.rowcount > 0: - session.commit() - - if not api_token: - raise Unauthorized("Access token is invalid") - - return api_token + # Cache miss - use Redis lock for single-flight mode + # This ensures only one request queries DB for the same token concurrently + return fetch_token_with_single_flight(auth_token, scope) class DatasetApiResource(Resource): diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index c0279f893b..b0863f0a2c 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" + DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" + DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index af983f6d87..dea214163c 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -184,6 
+184,14 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh", "schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL), } + + if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: + imports.append("schedule.update_api_token_last_used_task") + beat_schedule["batch_update_api_token_last_used"] = { + "task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used", + "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), + } + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/schedule/update_api_token_last_used_task.py b/api/schedule/update_api_token_last_used_task.py new file mode 100644 index 0000000000..f0f304a671 --- /dev/null +++ b/api/schedule/update_api_token_last_used_task.py @@ -0,0 +1,114 @@ +""" +Scheduled task to batch-update API token last_used_at timestamps. + +Instead of updating the database on every request, token usage is recorded +in Redis as lightweight SET keys (api_token_active:{scope}:{token}). +This task runs periodically (default every 30 minutes) to flush those +records into the database in a single batch operation. +""" + +import logging +import time +from datetime import datetime + +import click +from sqlalchemy import update +from sqlalchemy.orm import Session + +import app +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.model import ApiToken +from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX + +logger = logging.getLogger(__name__) + + +@app.celery.task(queue="api_token") +def batch_update_api_token_last_used(): + """ + Batch update last_used_at for all recently active API tokens. + + Scans Redis for api_token_active:* keys, parses the token and scope + from each key, and performs a batch database update. 
+ """ + click.echo(click.style("batch_update_api_token_last_used: start.", fg="green")) + start_at = time.perf_counter() + + updated_count = 0 + scanned_count = 0 + + try: + # Collect all active token keys and their values (the actual usage timestamps) + token_entries: list[tuple[str, str | None, datetime]] = [] # (token, scope, usage_time) + keys_to_delete: list[str | bytes] = [] + + for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200): + if isinstance(key, bytes): + key = key.decode("utf-8") + scanned_count += 1 + + # Read the value (ISO timestamp recorded at actual request time) + value = redis_client.get(key) + if not value: + keys_to_delete.append(key) + continue + + if isinstance(value, bytes): + value = value.decode("utf-8") + + try: + usage_time = datetime.fromisoformat(value) + except (ValueError, TypeError): + logger.warning("Invalid timestamp in key %s: %s", key, value) + keys_to_delete.append(key) + continue + + # Parse token info from key: api_token_active:{scope}:{token} + suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX) :] + parts = suffix.split(":", 1) + if len(parts) == 2: + scope_str, token = parts + scope = None if scope_str == "None" else scope_str + token_entries.append((token, scope, usage_time)) + keys_to_delete.append(key) + + if not token_entries: + click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow")) + # Still clean up any invalid keys + if keys_to_delete: + redis_client.delete(*keys_to_delete) + return + + # Update each token in its own short transaction to avoid long transactions + for token, scope, usage_time in token_entries: + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + stmt = ( + update(ApiToken) + .where( + ApiToken.token == token, + ApiToken.type == scope, + (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < usage_time)), + ) + .values(last_used_at=usage_time) + ) + result = session.execute(stmt) + rowcount = getattr(result, "rowcount", 0) + if rowcount > 0: + updated_count += 1 + + # Delete processed keys from Redis + if keys_to_delete: + redis_client.delete(*keys_to_delete) + + except Exception: + logger.exception("batch_update_api_token_last_used failed") + + elapsed = time.perf_counter() - start_at + click.echo( + click.style( + f"batch_update_api_token_last_used: done. " + f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s", + fg="green", + ) + ) diff --git a/api/services/api_token_service.py b/api/services/api_token_service.py new file mode 100644 index 0000000000..98cb5c0620 --- /dev/null +++ b/api/services/api_token_service.py @@ -0,0 +1,330 @@ +""" +API Token Service + +Handles all API token caching, validation, and usage recording. +Includes Redis cache operations, database queries, and single-flight concurrency control. +""" + +import logging +from datetime import datetime +from typing import Any + +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Session +from werkzeug.exceptions import Unauthorized + +from extensions.ext_database import db +from extensions.ext_redis import redis_client, redis_fallback +from libs.datetime_utils import naive_utc_now +from models.model import ApiToken + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------- +# Pydantic DTO +# --------------------------------------------------------------------- + + +class CachedApiToken(BaseModel): + """ + Pydantic model for cached API token data. 
+
+    This is NOT a SQLAlchemy model instance, but a plain Pydantic model
+    that mimics the ApiToken model interface for read-only access.
+    """
+
+    id: str
+    app_id: str | None
+    tenant_id: str | None
+    type: str
+    token: str
+    last_used_at: datetime | None
+    created_at: datetime | None
+
+    def __repr__(self) -> str:
+        return f"<CachedApiToken id={self.id} type={self.type}>"
+
+
+# ---------------------------------------------------------------------
+# Cache configuration
+# ---------------------------------------------------------------------
+
+CACHE_KEY_PREFIX = "api_token"
+CACHE_TTL_SECONDS = 600  # 10 minutes
+CACHE_NULL_TTL_SECONDS = 60  # 1 minute for non-existent tokens
+ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
+
+
+# ---------------------------------------------------------------------
+# Cache class
+# ---------------------------------------------------------------------
+
+
+class ApiTokenCache:
+    """
+    Redis cache wrapper for API tokens.
+    Handles serialization, deserialization, and cache invalidation.
+    """
+
+    @staticmethod
+    def make_active_key(token: str, scope: str | None = None) -> str:
+        """Generate Redis key for recording token usage."""
+        return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"
+
+    @staticmethod
+    def _make_tenant_index_key(tenant_id: str) -> str:
+        """Generate Redis key for tenant token index."""
+        return f"tenant_tokens:{tenant_id}"
+
+    @staticmethod
+    def _make_cache_key(token: str, scope: str | None = None) -> str:
+        """Generate cache key for the given token and scope."""
+        scope_str = scope or "any"
+        return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"
+
+    @staticmethod
+    def _serialize_token(api_token: Any) -> bytes:
+        """Serialize ApiToken object to JSON bytes."""
+        if isinstance(api_token, CachedApiToken):
+            return api_token.model_dump_json().encode("utf-8")
+
+        cached = CachedApiToken(
+            id=str(api_token.id),
+            app_id=str(api_token.app_id) if api_token.app_id else None,
+            tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
+            type=api_token.type,
+            token=api_token.token,
+            last_used_at=api_token.last_used_at,
+            created_at=api_token.created_at,
+        )
+        return cached.model_dump_json().encode("utf-8")
+
+    @staticmethod
+    def _deserialize_token(cached_data: bytes | str) -> Any:
+        """Deserialize JSON bytes/string back to a CachedApiToken Pydantic model."""
+        if cached_data in {b"null", "null"}:
+            return None
+
+        try:
+            if isinstance(cached_data, bytes):
+                cached_data = cached_data.decode("utf-8")
+            return CachedApiToken.model_validate_json(cached_data)
+        except Exception as e:
+            logger.warning("Failed to deserialize token from cache: %s", e)
+            return None
+
+    @staticmethod
+    @redis_fallback(default_return=None)
+    def get(token: str, scope: str | None) -> Any | None:
+        """Get API token from cache."""
+        cache_key = ApiTokenCache._make_cache_key(token, scope)
+        cached_data = redis_client.get(cache_key)
+
+        if cached_data is None:
+            logger.debug("Cache miss for token key: %s", cache_key)
+            return None
+
+        logger.debug("Cache hit for token key: %s", cache_key)
+        return ApiTokenCache._deserialize_token(cached_data)
+
+    @staticmethod
+    def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
+        """Add cache key to tenant index for efficient invalidation."""
+        if not tenant_id:
+            return
+
+        try:
+            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
+            redis_client.sadd(index_key, cache_key)
+            redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
+        except Exception as e:
+            logger.warning("Failed to update tenant index: %s", e)
+
+    @staticmethod
+    def
_remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None: + """Remove cache key from tenant index.""" + if not tenant_id: + return + + try: + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + redis_client.srem(index_key, cache_key) + except Exception as e: + logger.warning("Failed to remove from tenant index: %s", e) + + @staticmethod + @redis_fallback(default_return=False) + def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool: + """Set API token in cache.""" + cache_key = ApiTokenCache._make_cache_key(token, scope) + + if api_token is None: + cached_value = b"null" + ttl = CACHE_NULL_TTL_SECONDS + else: + cached_value = ApiTokenCache._serialize_token(api_token) + + try: + redis_client.setex(cache_key, ttl, cached_value) + + if api_token is not None and hasattr(api_token, "tenant_id"): + ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key) + + logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl) + return True + except Exception as e: + logger.warning("Failed to cache token: %s", e) + return False + + @staticmethod + @redis_fallback(default_return=False) + def delete(token: str, scope: str | None = None) -> bool: + """Delete API token from cache.""" + if scope is None: + pattern = f"{CACHE_KEY_PREFIX}:*:{token}" + try: + keys_to_delete = list(redis_client.scan_iter(match=pattern)) + if keys_to_delete: + redis_client.delete(*keys_to_delete) + logger.info("Deleted %d cache entries for token", len(keys_to_delete)) + return True + except Exception as e: + logger.warning("Failed to delete token cache with pattern: %s", e) + return False + else: + cache_key = ApiTokenCache._make_cache_key(token, scope) + try: + tenant_id = None + try: + cached_data = redis_client.get(cache_key) + if cached_data and cached_data != b"null": + cached_token = ApiTokenCache._deserialize_token(cached_data) + if cached_token: + tenant_id = cached_token.tenant_id + except Exception as e: + logger.debug("Failed to get tenant_id for cache cleanup: %s", e) + + redis_client.delete(cache_key) + + if tenant_id: + ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key) + + logger.info("Deleted cache for key: %s", cache_key) + return True + except Exception as e: + logger.warning("Failed to delete token cache: %s", e) + return False + + @staticmethod + @redis_fallback(default_return=False) + def invalidate_by_tenant(tenant_id: str) -> bool: + """Invalidate all API token caches for a specific tenant via tenant index.""" + try: + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + cache_keys = redis_client.smembers(index_key) + + if cache_keys: + deleted_count = 0 + for cache_key in cache_keys: + if isinstance(cache_key, bytes): + cache_key = cache_key.decode("utf-8") + redis_client.delete(cache_key) + deleted_count += 1 + + redis_client.delete(index_key) + + logger.info( + "Invalidated %d token cache entries for tenant: %s", + deleted_count, + tenant_id, + ) + else: + logger.info( + "No tenant index found for %s, relying on TTL expiration", + tenant_id, + ) + + return True + + except Exception as e: + logger.warning("Failed to invalidate tenant token cache: %s", e) + return False + + +# --------------------------------------------------------------------- +# Token usage recording (for batch update) +# --------------------------------------------------------------------- + + +def record_token_usage(auth_token: str, scope: str | None) -> None: + """ + Record token usage in Redis for later batch update by a 
scheduled job. + + Instead of dispatching a Celery task per request, we simply SET a key in Redis. + A Celery Beat scheduled task will periodically scan these keys and batch-update + last_used_at in the database. + """ + try: + key = ApiTokenCache.make_active_key(auth_token, scope) + redis_client.set(key, naive_utc_now().isoformat(), ex=3600) + except Exception as e: + logger.warning("Failed to record token usage: %s", e) + + +# --------------------------------------------------------------------- +# Database query + single-flight +# --------------------------------------------------------------------- + + +def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken: + """ + Query API token from database and cache the result. + + Raises Unauthorized if token is invalid. + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) + + if not api_token: + ApiTokenCache.set(auth_token, scope, None) + raise Unauthorized("Access token is invalid") + + ApiTokenCache.set(auth_token, scope, api_token) + record_token_usage(auth_token, scope) + return api_token + + +def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any: + """ + Fetch token from DB with single-flight pattern using Redis lock. + + Ensures only one concurrent request queries the database for the same token. + Falls back to direct query if lock acquisition fails. + """ + logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope) + + lock_key = f"api_token_query_lock:{scope}:{auth_token}" + lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5) + + try: + if lock.acquire(blocking=True): + try: + cached_token = ApiTokenCache.get(auth_token, scope) + if cached_token is not None: + logger.debug("Token cached by concurrent request, using cached version") + return cached_token + + return query_token_from_db(auth_token, scope) + finally: + lock.release() + else: + logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10]) + return query_token_from_db(auth_token, scope) + except Unauthorized: + raise + except Exception as e: + logger.warning("Redis lock failed for token query: %s, proceeding anyway", e) + return query_token_from_db(auth_token, scope) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 6240f2200f..b1840662ff 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -48,6 +48,7 @@ from models.workflow import ( WorkflowArchiveLog, ) from repositories.factory import DifyAPIRepositoryFactory +from services.api_token_service import ApiTokenCache logger = logging.getLogger(__name__) @@ -134,6 +135,12 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(session, api_token_id: str): + # Fetch token details for cache invalidation + token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first() + if token_obj: + # Invalidate cache before deletion + ApiTokenCache.delete(token_obj.token, token_obj.type) + session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py new file mode 100644 
index 0000000000..166fcb515f --- /dev/null +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -0,0 +1,375 @@ +""" +Integration tests for API Token Cache with Redis. + +These tests require: +- Redis server running +- Test database configured +""" + +import time +from datetime import datetime, timedelta +from unittest.mock import patch + +import pytest + +from extensions.ext_redis import redis_client +from models.model import ApiToken +from services.api_token_service import ApiTokenCache, CachedApiToken + + +class TestApiTokenCacheRedisIntegration: + """Integration tests with real Redis.""" + + def setup_method(self): + """Setup test fixtures and clean Redis.""" + self.test_token = "test-integration-token-123" + self.test_scope = "app" + self.cache_key = f"api_token:{self.test_scope}:{self.test_token}" + + # Clean up any existing test data + self._cleanup() + + def teardown_method(self): + """Cleanup test data from Redis.""" + self._cleanup() + + def _cleanup(self): + """Remove test data from Redis.""" + try: + redis_client.delete(self.cache_key) + redis_client.delete(ApiTokenCache._make_tenant_index_key("test-tenant-id")) + redis_client.delete(ApiTokenCache.make_active_key(self.test_token, self.test_scope)) + except Exception: + pass # Ignore cleanup errors + + def test_cache_set_and_get_with_real_redis(self): + """Test cache set and get operations with real Redis.""" + from unittest.mock import MagicMock + + mock_token = MagicMock() + mock_token.id = "test-id-123" + mock_token.app_id = "test-app-456" + mock_token.tenant_id = "test-tenant-789" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = datetime.now() + mock_token.created_at = datetime.now() - timedelta(days=30) + + # Set in cache + result = ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + assert result is True + + # Verify in Redis + cached_data = redis_client.get(self.cache_key) + assert cached_data is not None + + # Get from cache + cached_token = ApiTokenCache.get(self.test_token, self.test_scope) + assert cached_token is not None + assert isinstance(cached_token, CachedApiToken) + assert cached_token.id == "test-id-123" + assert cached_token.app_id == "test-app-456" + assert cached_token.tenant_id == "test-tenant-789" + assert cached_token.type == "app" + assert cached_token.token == self.test_token + + def test_cache_ttl_with_real_redis(self): + """Test cache TTL is set correctly.""" + from unittest.mock import MagicMock + + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + + ttl = redis_client.ttl(self.cache_key) + assert 595 <= ttl <= 600 # Should be around 600 seconds (10 minutes) + + def test_cache_null_value_for_invalid_token(self): + """Test caching null value for invalid tokens.""" + result = ApiTokenCache.set(self.test_token, self.test_scope, None) + assert result is True + + cached_data = redis_client.get(self.cache_key) + assert cached_data == b"null" + + cached_token = ApiTokenCache.get(self.test_token, self.test_scope) + assert cached_token is None + + ttl = redis_client.ttl(self.cache_key) + assert 55 <= ttl <= 60 + + def test_cache_delete_with_real_redis(self): + """Test cache deletion with real Redis.""" + from unittest.mock import MagicMock + + 
mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + assert redis_client.exists(self.cache_key) == 1 + + result = ApiTokenCache.delete(self.test_token, self.test_scope) + assert result is True + assert redis_client.exists(self.cache_key) == 0 + + def test_tenant_index_creation(self): + """Test tenant index is created when caching token.""" + from unittest.mock import MagicMock + + tenant_id = "test-tenant-id" + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = tenant_id + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + assert redis_client.exists(index_key) == 1 + + members = redis_client.smembers(index_key) + cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members] + assert self.cache_key in cache_keys + + def test_invalidate_by_tenant_via_index(self): + """Test tenant-wide cache invalidation using index (fast path).""" + from unittest.mock import MagicMock + + tenant_id = "test-tenant-id" + + for i in range(3): + token_value = f"test-token-{i}" + mock_token = MagicMock() + mock_token.id = f"test-id-{i}" + mock_token.app_id = "test-app" + mock_token.tenant_id = tenant_id + mock_token.type = "app" + mock_token.token = token_value + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(token_value, "app", mock_token) + + for i in range(3): + key = f"api_token:app:test-token-{i}" + assert redis_client.exists(key) == 1 + + result = ApiTokenCache.invalidate_by_tenant(tenant_id) + assert result is True + + for i in range(3): + key = f"api_token:app:test-token-{i}" + assert redis_client.exists(key) == 0 + + assert redis_client.exists(ApiTokenCache._make_tenant_index_key(tenant_id)) == 0 + + def test_concurrent_cache_access(self): + """Test concurrent cache access doesn't cause issues.""" + import concurrent.futures + from unittest.mock import MagicMock + + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = "app" + mock_token.token = self.test_token + mock_token.last_used_at = None + mock_token.created_at = datetime.now() + + ApiTokenCache.set(self.test_token, self.test_scope, mock_token) + + def get_from_cache(): + return ApiTokenCache.get(self.test_token, self.test_scope) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(get_from_cache) for _ in range(50)] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + assert len(results) == 50 + assert all(r is not None for r in results) + assert all(isinstance(r, CachedApiToken) for r in results) + + +class TestTokenUsageRecording: + """Tests for recording token usage in Redis (batch update approach).""" + + def setup_method(self): + self.test_token = "test-usage-token" + self.test_scope = "app" + self.active_key = ApiTokenCache.make_active_key(self.test_token, self.test_scope) + + def teardown_method(self): + try: + redis_client.delete(self.active_key) + except 
Exception: + pass + + def test_record_token_usage_sets_redis_key(self): + """Test that record_token_usage writes an active key to Redis.""" + from services.api_token_service import record_token_usage + + record_token_usage(self.test_token, self.test_scope) + + # Key should exist + assert redis_client.exists(self.active_key) == 1 + + # Value should be an ISO timestamp + value = redis_client.get(self.active_key) + if isinstance(value, bytes): + value = value.decode("utf-8") + datetime.fromisoformat(value) # Should not raise + + def test_record_token_usage_has_ttl(self): + """Test that active keys have a TTL as safety net.""" + from services.api_token_service import record_token_usage + + record_token_usage(self.test_token, self.test_scope) + + ttl = redis_client.ttl(self.active_key) + assert 3595 <= ttl <= 3600 # ~1 hour + + def test_record_token_usage_overwrites(self): + """Test that repeated calls overwrite the same key (no accumulation).""" + from services.api_token_service import record_token_usage + + record_token_usage(self.test_token, self.test_scope) + first_value = redis_client.get(self.active_key) + + time.sleep(0.01) # Tiny delay so timestamp differs + + record_token_usage(self.test_token, self.test_scope) + second_value = redis_client.get(self.active_key) + + # Key count should still be 1 (overwritten, not accumulated) + assert redis_client.exists(self.active_key) == 1 + + +class TestEndToEndCacheFlow: + """End-to-end integration test for complete cache flow.""" + + @pytest.mark.usefixtures("db_session") + def test_complete_flow_cache_miss_then_hit(self, db_session): + """ + Test complete flow: + 1. First request (cache miss) -> query DB -> cache result + 2. Second request (cache hit) -> return from cache + 3. Verify Redis state + """ + test_token_value = "test-e2e-token" + test_scope = "app" + + test_token = ApiToken() + test_token.id = "test-e2e-id" + test_token.token = test_token_value + test_token.type = test_scope + test_token.app_id = "test-app" + test_token.tenant_id = "test-tenant" + test_token.last_used_at = None + test_token.created_at = datetime.now() + + db_session.add(test_token) + db_session.commit() + + try: + # Step 1: Cache miss - set token in cache + ApiTokenCache.set(test_token_value, test_scope, test_token) + + cache_key = f"api_token:{test_scope}:{test_token_value}" + assert redis_client.exists(cache_key) == 1 + + # Step 2: Cache hit - get from cache + cached_token = ApiTokenCache.get(test_token_value, test_scope) + assert cached_token is not None + assert cached_token.id == test_token.id + assert cached_token.token == test_token_value + + # Step 3: Verify tenant index + index_key = ApiTokenCache._make_tenant_index_key(test_token.tenant_id) + assert redis_client.exists(index_key) == 1 + assert cache_key.encode() in redis_client.smembers(index_key) + + # Step 4: Delete and verify cleanup + ApiTokenCache.delete(test_token_value, test_scope) + assert redis_client.exists(cache_key) == 0 + assert cache_key.encode() not in redis_client.smembers(index_key) + + finally: + db_session.delete(test_token) + db_session.commit() + redis_client.delete(f"api_token:{test_scope}:{test_token_value}") + redis_client.delete(ApiTokenCache._make_tenant_index_key(test_token.tenant_id)) + + def test_high_concurrency_simulation(self): + """Simulate high concurrency access to cache.""" + import concurrent.futures + from unittest.mock import MagicMock + + test_token_value = "test-concurrent-token" + test_scope = "app" + + mock_token = MagicMock() + mock_token.id = "concurrent-id" + 
mock_token.app_id = "test-app" + mock_token.tenant_id = "test-tenant" + mock_token.type = test_scope + mock_token.token = test_token_value + mock_token.last_used_at = datetime.now() + mock_token.created_at = datetime.now() + + ApiTokenCache.set(test_token_value, test_scope, mock_token) + + try: + + def read_cache(): + return ApiTokenCache.get(test_token_value, test_scope) + + start_time = time.time() + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(read_cache) for _ in range(100)] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + elapsed = time.time() - start_time + + assert len(results) == 100 + assert all(r is not None for r in results) + + assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads" + + finally: + ApiTokenCache.delete(test_token_value, test_scope) + redis_client.delete(ApiTokenCache._make_tenant_index_key(mock_token.tenant_id)) + + +class TestRedisFailover: + """Test behavior when Redis is unavailable.""" + + @patch("services.api_token_service.redis_client") + def test_graceful_degradation_when_redis_fails(self, mock_redis): + """Test system degrades gracefully when Redis is unavailable.""" + from redis import RedisError + + mock_redis.get.side_effect = RedisError("Connection failed") + mock_redis.setex.side_effect = RedisError("Connection failed") + + result_get = ApiTokenCache.get("test-token", "app") + assert result_get is None + + result_set = ApiTokenCache.set("test-token", "app", None) + assert result_set is False diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index d3a4d69f07..34d48fa94e 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -132,6 +132,8 @@ class TestCelerySSLConfiguration: mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0 mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15 + mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False + mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30 with patch("extensions.ext_celery.dify_config", mock_config): from dify_app import DifyApp diff --git a/api/tests/unit_tests/libs/test_api_token_cache.py b/api/tests/unit_tests/libs/test_api_token_cache.py new file mode 100644 index 0000000000..fa4c5e77a7 --- /dev/null +++ b/api/tests/unit_tests/libs/test_api_token_cache.py @@ -0,0 +1,250 @@ +""" +Unit tests for API Token Cache module. 
+""" + +import json +from datetime import datetime +from unittest.mock import MagicMock, patch + +from services.api_token_service import ( + CACHE_KEY_PREFIX, + CACHE_NULL_TTL_SECONDS, + CACHE_TTL_SECONDS, + ApiTokenCache, + CachedApiToken, +) + + +class TestApiTokenCache: + """Test cases for ApiTokenCache class.""" + + def setup_method(self): + """Setup test fixtures.""" + self.mock_token = MagicMock() + self.mock_token.id = "test-token-id-123" + self.mock_token.app_id = "test-app-id-456" + self.mock_token.tenant_id = "test-tenant-id-789" + self.mock_token.type = "app" + self.mock_token.token = "test-token-value-abc" + self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0) + self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0) + + def test_make_cache_key(self): + """Test cache key generation.""" + # Test with scope + key = ApiTokenCache._make_cache_key("my-token", "app") + assert key == f"{CACHE_KEY_PREFIX}:app:my-token" + + # Test without scope + key = ApiTokenCache._make_cache_key("my-token", None) + assert key == f"{CACHE_KEY_PREFIX}:any:my-token" + + def test_serialize_token(self): + """Test token serialization.""" + serialized = ApiTokenCache._serialize_token(self.mock_token) + data = json.loads(serialized) + + assert data["id"] == "test-token-id-123" + assert data["app_id"] == "test-app-id-456" + assert data["tenant_id"] == "test-tenant-id-789" + assert data["type"] == "app" + assert data["token"] == "test-token-value-abc" + assert data["last_used_at"] == "2026-02-03T10:00:00" + assert data["created_at"] == "2026-01-01T00:00:00" + + def test_serialize_token_with_nulls(self): + """Test token serialization with None values.""" + mock_token = MagicMock() + mock_token.id = "test-id" + mock_token.app_id = None + mock_token.tenant_id = None + mock_token.type = "dataset" + mock_token.token = "test-token" + mock_token.last_used_at = None + mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0) + + serialized = ApiTokenCache._serialize_token(mock_token) + data = json.loads(serialized) + + assert data["app_id"] is None + assert data["tenant_id"] is None + assert data["last_used_at"] is None + + def test_deserialize_token(self): + """Test token deserialization.""" + cached_data = json.dumps( + { + "id": "test-id", + "app_id": "test-app", + "tenant_id": "test-tenant", + "type": "app", + "token": "test-token", + "last_used_at": "2026-02-03T10:00:00", + "created_at": "2026-01-01T00:00:00", + } + ) + + result = ApiTokenCache._deserialize_token(cached_data) + + assert isinstance(result, CachedApiToken) + assert result.id == "test-id" + assert result.app_id == "test-app" + assert result.tenant_id == "test-tenant" + assert result.type == "app" + assert result.token == "test-token" + assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0) + assert result.created_at == datetime(2026, 1, 1, 0, 0, 0) + + def test_deserialize_null_token(self): + """Test deserialization of null token (cached miss).""" + result = ApiTokenCache._deserialize_token("null") + assert result is None + + def test_deserialize_invalid_json(self): + """Test deserialization with invalid JSON.""" + result = ApiTokenCache._deserialize_token("invalid-json{") + assert result is None + + @patch("services.api_token_service.redis_client") + def test_get_cache_hit(self, mock_redis): + """Test cache hit scenario.""" + cached_data = json.dumps( + { + "id": "test-id", + "app_id": "test-app", + "tenant_id": "test-tenant", + "type": "app", + "token": "test-token", + "last_used_at": "2026-02-03T10:00:00", + "created_at": 
"2026-01-01T00:00:00", + } + ).encode("utf-8") + mock_redis.get.return_value = cached_data + + result = ApiTokenCache.get("test-token", "app") + + assert result is not None + assert isinstance(result, CachedApiToken) + assert result.app_id == "test-app" + mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token") + + @patch("services.api_token_service.redis_client") + def test_get_cache_miss(self, mock_redis): + """Test cache miss scenario.""" + mock_redis.get.return_value = None + + result = ApiTokenCache.get("test-token", "app") + + assert result is None + mock_redis.get.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_set_valid_token(self, mock_redis): + """Test setting a valid token in cache.""" + result = ApiTokenCache.set("test-token", "app", self.mock_token) + + assert result is True + mock_redis.setex.assert_called_once() + args = mock_redis.setex.call_args[0] + assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token" + assert args[1] == CACHE_TTL_SECONDS + + @patch("services.api_token_service.redis_client") + def test_set_null_token(self, mock_redis): + """Test setting a null token (cache penetration prevention).""" + result = ApiTokenCache.set("invalid-token", "app", None) + + assert result is True + mock_redis.setex.assert_called_once() + args = mock_redis.setex.call_args[0] + assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token" + assert args[1] == CACHE_NULL_TTL_SECONDS + assert args[2] == b"null" + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope(self, mock_redis): + """Test deleting token cache with specific scope.""" + result = ApiTokenCache.delete("test-token", "app") + + assert result is True + mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token") + + @patch("services.api_token_service.redis_client") + def test_delete_without_scope(self, mock_redis): + """Test deleting token cache without scope (delete all).""" + # Mock scan_iter to return an iterator of keys + mock_redis.scan_iter.return_value = iter( + [ + b"api_token:app:test-token", + b"api_token:dataset:test-token", + ] + ) + + result = ApiTokenCache.delete("test-token", None) + + assert result is True + # Verify scan_iter was called with the correct pattern + mock_redis.scan_iter.assert_called_once() + call_args = mock_redis.scan_iter.call_args + assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token" + + # Verify delete was called with all matched keys + mock_redis.delete.assert_called_once_with( + b"api_token:app:test-token", + b"api_token:dataset:test-token", + ) + + @patch("services.api_token_service.redis_client") + def test_redis_fallback_on_exception(self, mock_redis): + """Test Redis fallback when Redis is unavailable.""" + from redis import RedisError + + mock_redis.get.side_effect = RedisError("Connection failed") + + result = ApiTokenCache.get("test-token", "app") + + # Should return None (fallback) instead of raising exception + assert result is None + + +class TestApiTokenCacheIntegration: + """Integration test scenarios.""" + + @patch("services.api_token_service.redis_client") + def test_full_cache_lifecycle(self, mock_redis): + """Test complete cache lifecycle: set -> get -> delete.""" + # Setup mock token + mock_token = MagicMock() + mock_token.id = "id-123" + mock_token.app_id = "app-456" + mock_token.tenant_id = "tenant-789" + mock_token.type = "app" + mock_token.token = "token-abc" + mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0) + mock_token.created_at = 
datetime(2026, 1, 1, 0, 0, 0) + + # 1. Set token in cache + ApiTokenCache.set("token-abc", "app", mock_token) + assert mock_redis.setex.called + + # 2. Simulate cache hit + cached_data = ApiTokenCache._serialize_token(mock_token) + mock_redis.get.return_value = cached_data # bytes from model_dump_json().encode() + + retrieved = ApiTokenCache.get("token-abc", "app") + assert retrieved is not None + assert isinstance(retrieved, CachedApiToken) + + # 3. Delete from cache + ApiTokenCache.delete("token-abc", "app") + assert mock_redis.delete.called + + @patch("services.api_token_service.redis_client") + def test_cache_penetration_prevention(self, mock_redis): + """Test that non-existent tokens are cached as null.""" + # Set null token (cache miss) + ApiTokenCache.set("non-existent-token", "app", None) + + args = mock_redis.setex.call_args[0] + assert args[2] == b"null" + assert args[1] == CACHE_NULL_TTL_SECONDS # Shorter TTL for null values
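
For reference, a minimal sketch (not part of the patch) of how the new helpers compose on the request path. It mirrors `validate_and_get_api_token()` in `api/controllers/service_api/wraps.py`; the wrapper name `resolve_api_token` is illustrative only.

```python
from services.api_token_service import (
    ApiTokenCache,
    fetch_token_with_single_flight,
    record_token_usage,
)


def resolve_api_token(auth_token: str, scope: str | None):
    # Cache hit: return the CachedApiToken DTO and record usage in Redis.
    # The api_token_active:* key is later flushed into last_used_at by the
    # batch_update_api_token_last_used beat task, so no DB write happens here.
    cached = ApiTokenCache.get(auth_token, scope)
    if cached is not None:
        record_token_usage(auth_token, scope)
        return cached

    # Cache miss: single-flight lookup. Only one concurrent request per token
    # queries the database; the result (or a short-lived "null" marker for an
    # invalid token) is cached, and Unauthorized is raised if the token is bad.
    return fetch_token_with_single_flight(auth_token, scope)
```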