diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 5e97e73e35..803d4404c7 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -1039,7 +1039,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
     )
     TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
         description="Proactive subscription refresh threshold in seconds",
-        default=300,
+        default=60 * 60,
     )
 
 
diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py
index f91a849a71..e56ad7da7a 100644
--- a/api/schedule/trigger_provider_refresh_task.py
+++ b/api/schedule/trigger_provider_refresh_task.py
@@ -1,7 +1,18 @@
 import logging
+import math
 import time
+from collections.abc import Iterable, Sequence
+
+from sqlalchemy import ColumnElement, and_, func, or_, select
+from sqlalchemy.engine.row import Row
+from sqlalchemy.orm import Session
 
 import app
+from configs import dify_config
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.trigger import TriggerSubscription
+from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh
 
 logger = logging.getLogger(__name__)
 
@@ -10,12 +21,81 @@ def _now_ts() -> int:
     return int(time.time())
 
 
-@app.celery.task(queue="trigger_refresh")
+def _build_due_filter(now_ts: int):
+    """Build SQLAlchemy filter for due credential or subscription refresh."""
+    credential_due: ColumnElement[bool] = and_(
+        TriggerSubscription.credential_expires_at != -1,
+        TriggerSubscription.credential_expires_at
+        <= now_ts + int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS),
+    )
+    subscription_due: ColumnElement[bool] = and_(
+        TriggerSubscription.expires_at != -1,
+        TriggerSubscription.expires_at <= now_ts + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS),
+    )
+    return or_(credential_due, subscription_due)
+
+
+def _lock_keys(rows: Sequence[tuple[str, str]]) -> list[str]:
+    """Generate redis lock keys for rows as (tenant_id, subscription_id)."""
+    return [f"trigger_provider_refresh_lock:{tenant_id}_{sid}" for tenant_id, sid in rows]
+
+
+def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]:
+    """Attempt to acquire locks in a single pipelined round-trip.
+
+    Returns a list of booleans indicating which locks were acquired.
+    """
+    pipe = redis_client.pipeline(transaction=False)
+    for key in keys:
+        pipe.set(key, b"1", ex=ttl_seconds, nx=True)
+    results = pipe.execute()
+    return [bool(r) for r in results]
+
+
+@app.celery.task(queue="trigger_refresh_publisher")
 def trigger_provider_refresh() -> None:
     """
-    Simple trigger provider refresh task.
-    - Scans due trigger subscriptions in small batches
-    - Refreshes OAuth credentials if needed
-    - Refreshes subscription metadata if needed
+    Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks.
     """
-    pass
\ No newline at end of file
+    now = _now_ts()
+
+    batch_size = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE)
+    # Lock TTL must outlive the executor's work; floor at 5 minutes.
+    lock_ttl = max(
+        300,
+        int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS),
+    )
+
+    queued = 0
+    with Session(db.engine, expire_on_commit=False) as session:
+        due_filter: ColumnElement[bool] = _build_due_filter(now_ts=now)
+        total_due: int = session.scalar(statement=select(func.count()).where(due_filter)) or 0
+        if total_due == 0:
+            return
+
+        pages: int = math.ceil(total_due / batch_size)
+        for page in range(pages):
+            offset: int = page * batch_size
+            subscription_rows: Sequence[Row[tuple[str, str]]] = session.execute(
+                select(TriggerSubscription.tenant_id, TriggerSubscription.id)
+                .where(due_filter)
+                .order_by(TriggerSubscription.updated_at.asc())
+                .offset(offset)
+                .limit(batch_size)
+            ).all()
+            if not subscription_rows:
+                continue
+
+            subscriptions: list[tuple[str, str]] = [
+                (str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows
+            ]
+            lock_keys: list[str] = _lock_keys(subscriptions)
+            acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
+
+            for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
+                if not is_locked:
+                    continue
+                trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
+                queued += 1
+
+    logger.info("Trigger provider refresh: %d due subscriptions, %d tasks queued", total_due, queued)
diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py
new file mode 100644
index 0000000000..e0d6d195db
--- /dev/null
+++ b/api/tasks/trigger_subscription_refresh_tasks.py
@@ -0,0 +1,98 @@
+import logging
+import time
+
+from celery import shared_task
+from sqlalchemy.orm import Session
+
+from core.plugin.entities.plugin_daemon import CredentialType
+from core.trigger.provider import PluginTriggerProviderController
+from core.trigger.trigger_manager import TriggerManager
+from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_properties
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.provider_ids import TriggerProviderID
+from models.trigger import TriggerSubscription
+from services.trigger.trigger_provider_service import TriggerProviderService
+
+logger = logging.getLogger(__name__)
+
+
+def _now_ts() -> int:
+    return int(time.time())
+
+
+@shared_task(queue="trigger_refresh_executor")
+def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
+    """Refresh a trigger subscription if needed, guarded by a Redis in-flight lock."""
+    lock_key = f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}"
+    if not redis_client.get(lock_key):  # Lock missing means job already timed out/handled
+        logger.debug("Refresh lock missing, skip: %s", lock_key)
+        return
+
+    try:
+        now: int = _now_ts()
+        with Session(db.engine) as session:
+            subscription: TriggerSubscription | None = (
+                session.query(TriggerSubscription)
+                .filter_by(tenant_id=tenant_id, id=subscription_id)
+                .first()
+            )
+
+            if not subscription:
+                logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id)
+                return
+
+            # Refresh OAuth token if already expired
+            if (
+                subscription.credential_expires_at != -1
+                and int(subscription.credential_expires_at) <= now
+                and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2
+            ):
+                try:
+                    TriggerProviderService.refresh_oauth_token(tenant_id, subscription.id)
+                except Exception:
+                    logger.exception("OAuth refresh failed for %s/%s", tenant_id, subscription.id)
+                    # proceed to subscription refresh; provider may still accept late refresh
+
+            # Only refresh subscription when it's actually expired
+            if subscription.expires_at != -1 and int(subscription.expires_at) <= now:
+                # Load decrypted subscription and properties
+                loaded = TriggerProviderService.get_subscription_by_id(
+                    tenant_id=tenant_id, subscription_id=subscription.id
+                )
+                if not loaded:
+                    logger.warning("Subscription vanished during refresh: tenant=%s id=%s", tenant_id, subscription_id)
+                    return
+
+                controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
+                    tenant_id, TriggerProviderID(loaded.provider_id)
+                )
+                refreshed = controller.refresh_trigger(
+                    subscription=loaded.to_entity(),
+                    credentials=loaded.credentials,
+                    credential_type=CredentialType.of(loaded.credential_type),
+                )
+
+                # Persist refreshed properties/expires_at with encryption
+                properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
+                    tenant_id=tenant_id,
+                    controller=controller,
+                    subscription=loaded,
+                )
+
+                db_sub: TriggerSubscription | None = (
+                    session.query(TriggerSubscription)
+                    .filter_by(tenant_id=tenant_id, id=subscription.id)
+                    .first()
+                )
+                if db_sub is not None:
+                    db_sub.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
+                    db_sub.expires_at = int(refreshed.expires_at)
+                    session.commit()
+                    properties_cache.delete()
+    finally:
+        try:
+            redis_client.delete(lock_key)
+        except Exception:
+            # Best-effort lock cleanup
+            pass