diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 5e97e73e35..803d4404c7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1039,7 +1039,7 @@ class CeleryScheduleTasksConfig(BaseSettings): ) TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field( description="Proactive subscription refresh threshold in seconds", - default=300, + default=60 * 60, ) diff --git a/api/core/trigger/utils/locks.py b/api/core/trigger/utils/locks.py new file mode 100644 index 0000000000..46833396e0 --- /dev/null +++ b/api/core/trigger/utils/locks.py @@ -0,0 +1,12 @@ +from collections.abc import Sequence +from itertools import starmap + + +def build_trigger_refresh_lock_key(tenant_id: str, subscription_id: str) -> str: + """Build the Redis lock key for trigger subscription refresh in-flight protection.""" + return f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}" + + +def build_trigger_refresh_lock_keys(pairs: Sequence[tuple[str, str]]) -> list[str]: + """Build Redis lock keys for a sequence of (tenant_id, subscription_id) pairs.""" + return list(starmap(build_trigger_refresh_lock_key, pairs)) diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py index f91a849a71..3b3e478793 100644 --- a/api/schedule/trigger_provider_refresh_task.py +++ b/api/schedule/trigger_provider_refresh_task.py @@ -1,7 +1,19 @@ import logging +import math import time +from collections.abc import Iterable, Sequence + +from sqlalchemy import ColumnElement, and_, func, or_, select +from sqlalchemy.engine.row import Row +from sqlalchemy.orm import Session import app +from configs import dify_config +from core.trigger.utils.locks import build_trigger_refresh_lock_keys +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.trigger import TriggerSubscription +from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh logger = logging.getLogger(__name__) @@ -10,12 +22,83 @@ def _now_ts() -> int: return int(time.time()) -@app.celery.task(queue="trigger_refresh") +def _build_due_filter(now_ts: int): + """Build SQLAlchemy filter for due credential or subscription refresh.""" + credential_due: ColumnElement[bool] = and_( + TriggerSubscription.credential_expires_at != -1, + TriggerSubscription.credential_expires_at + <= now_ts + int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS), + ) + subscription_due: ColumnElement[bool] = and_( + TriggerSubscription.expires_at != -1, + TriggerSubscription.expires_at <= now_ts + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS), + ) + return or_(credential_due, subscription_due) + + +def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]: + """Attempt to acquire locks in a single pipelined round-trip. + + Returns a list of booleans indicating which locks were acquired. + """ + pipe = redis_client.pipeline(transaction=False) + for key in keys: + pipe.set(key, b"1", ex=ttl_seconds, nx=True) + results = pipe.execute() + return [bool(r) for r in results] + + +@app.celery.task(queue="trigger_refresh_publisher") def trigger_provider_refresh() -> None: """ - Simple trigger provider refresh task. - - Scans due trigger subscriptions in small batches - - Refreshes OAuth credentials if needed - - Refreshes subscription metadata if needed + Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks. """ - pass \ No newline at end of file + now: int = _now_ts() + + batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE) + lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS)) + + with Session(db.engine, expire_on_commit=False) as session: + filter: ColumnElement[bool] = _build_due_filter(now_ts=now) + total_due: int = int(session.scalar(statement=select(func.count()).where(filter)) or 0) + logger.info("Trigger refresh scan start: due=%d", total_due) + if total_due == 0: + return + + pages: int = math.ceil(total_due / batch_size) + for page in range(pages): + offset: int = page * batch_size + subscription_rows: Sequence[Row[tuple[str, str]]] = session.execute( + select(TriggerSubscription.tenant_id, TriggerSubscription.id) + .where(filter) + .order_by(TriggerSubscription.updated_at.asc()) + .offset(offset) + .limit(batch_size) + ).all() + if not subscription_rows: + logger.debug("Trigger refresh page %d/%d empty", page + 1, pages) + continue + + subscriptions: list[tuple[str, str]] = [ + (str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows + ] + lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions) + acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl) + + enqueued: int = 0 + for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired): + if not is_locked: + continue + trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id) + enqueued += 1 + + logger.info( + "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d", + page + 1, + pages, + len(subscriptions), + sum(1 for x in acquired if x), + enqueued, + ) + + logger.info("Trigger refresh scan done: due=%d", total_due) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 4cc19b485a..1c474b2c48 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -1,5 +1,6 @@ import json import logging +import time as _time import uuid from collections.abc import Mapping from typing import Any, Optional @@ -18,6 +19,7 @@ from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, ) +from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import ( @@ -25,6 +27,7 @@ from core.trigger.utils.encryption import ( create_trigger_provider_encrypter_for_subscription, delete_cache_for_subscription, ) +from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import TriggerProviderID @@ -347,6 +350,91 @@ class TriggerProviderService: "expires_at": refreshed_credentials.expires_at, } + @classmethod + def refresh_subscription( + cls, + tenant_id: str, + subscription_id: str, + now: int | None = None, + ) -> Mapping[str, Any]: + """ + Refresh trigger subscription if expired. + + Args: + tenant_id: Tenant ID + subscription_id: Subscription instance ID + now: Current timestamp, defaults to `int(time.time())` + + Returns: + Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value) + """ + now_ts: int = int(now if now is not None else _time.time()) + + with Session(db.engine) as session: + subscription: TriggerSubscription | None = ( + session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + ) + if subscription is None: + raise ValueError(f"Trigger provider subscription {subscription_id} not found") + + if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts: + logger.debug( + "Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s", + tenant_id, + subscription_id, + subscription.expires_at, + now_ts, + ) + return {"result": "skipped", "expires_at": int(subscription.expires_at)} + + provider_id = TriggerProviderID(subscription.provider_id) + controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) + + # Decrypt credentials and properties for runtime + credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription( + tenant_id=tenant_id, + controller=controller, + subscription=subscription, + ) + properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties( + tenant_id=tenant_id, + controller=controller, + subscription=subscription, + ) + + decrypted_credentials = credential_encrypter.decrypt(subscription.credentials) + decrypted_properties = properties_encrypter.decrypt(subscription.properties) + + sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity( + expires_at=int(subscription.expires_at), + endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), + parameters=subscription.parameters, + properties=decrypted_properties, + ) + + refreshed: TriggerSubscriptionEntity = controller.refresh_trigger( + subscription=sub_entity, + credentials=decrypted_credentials, + credential_type=CredentialType.of(subscription.credential_type), + ) + + # Persist refreshed properties and expires_at + subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) + subscription.expires_at = int(refreshed.expires_at) + session.commit() + properties_cache.delete() + + logger.info( + "Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s", + tenant_id, + subscription_id, + subscription.expires_at, + ) + + return {"result": "success", "expires_at": int(refreshed.expires_at)} + @classmethod def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]: """ diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py new file mode 100644 index 0000000000..11324df881 --- /dev/null +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -0,0 +1,115 @@ +import logging +import time +from collections.abc import Mapping +from typing import Any + +from celery import shared_task +from sqlalchemy.orm import Session + +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.utils.locks import build_trigger_refresh_lock_key +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.trigger import TriggerSubscription +from services.trigger.trigger_provider_service import TriggerProviderService + +logger = logging.getLogger(__name__) + + +def _now_ts() -> int: + return int(time.time()) + + +def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None: + return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + + +def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None: + if ( + subscription.credential_expires_at != -1 + and int(subscription.credential_expires_at) <= now + and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2 + ): + logger.info( + "Refreshing OAuth token: tenant=%s subscription_id=%s expires_at=%s now=%s", + tenant_id, + subscription.id, + subscription.credential_expires_at, + now, + ) + try: + result: Mapping[str, Any] = TriggerProviderService.refresh_oauth_token( + tenant_id=tenant_id, subscription_id=subscription.id + ) + logger.info( + "OAuth token refreshed: tenant=%s subscription_id=%s result=%s", tenant_id, subscription.id, result + ) + except Exception: + logger.exception("OAuth refresh failed: tenant=%s subscription_id=%s", tenant_id, subscription.id) + + +def _refresh_subscription_if_expired( + tenant_id: str, + subscription: TriggerSubscription, + now: int, +) -> None: + if subscription.expires_at == -1 or int(subscription.expires_at) > now: + logger.debug( + "Subscription not due: tenant=%s subscription_id=%s expires_at=%s now=%s", + tenant_id, + subscription.id, + subscription.expires_at, + now, + ) + return + + try: + result: Mapping[str, Any] = TriggerProviderService.refresh_subscription( + tenant_id=tenant_id, subscription_id=subscription.id, now=now + ) + logger.info( + "Subscription refreshed: tenant=%s subscription_id=%s result=%s", + tenant_id, + subscription.id, + result.get("result"), + ) + except Exception: + logger.exception("Subscription refresh failed: tenant=%s id=%s", tenant_id, subscription.id) + + +@shared_task(queue="trigger_refresh_executor") +def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: + """Refresh a trigger subscription if needed, guarded by a Redis in-flight lock.""" + lock_key: str = build_trigger_refresh_lock_key(tenant_id, subscription_id) + if not redis_client.get(lock_key): + logger.debug("Refresh lock missing, skip: %s", lock_key) + return + + logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id) + try: + now: int = _now_ts() + with Session(db.engine) as session: + subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id) + + if not subscription: + logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id) + return + + logger.debug( + "Loaded subscription: tenant=%s id=%s cred_exp=%s sub_exp=%s now=%s", + tenant_id, + subscription.id, + subscription.credential_expires_at, + subscription.expires_at, + now, + ) + + _refresh_oauth_if_expired(tenant_id=tenant_id, subscription=subscription, now=now) + _refresh_subscription_if_expired(tenant_id=tenant_id, subscription=subscription, now=now) + finally: + try: + redis_client.delete(lock_key) + logger.debug("Lock released: %s", lock_key) + except Exception: + # Best-effort lock cleanup + logger.warning("Failed to release lock: %s", lock_key, exc_info=True)