From f4517d667bb5800d74ba0a4f7fb161760941d0e6 Mon Sep 17 00:00:00 2001 From: Harry Date: Tue, 21 Oct 2025 11:37:11 +0800 Subject: [PATCH 1/2] feat(trigger): enhance trigger provider refresh task with locking mechanism and due filter logic --- api/configs/feature/__init__.py | 2 +- api/schedule/trigger_provider_refresh_task.py | 89 +++++++++++++++-- .../trigger_subscription_refresh_tasks.py | 98 +++++++++++++++++++ 3 files changed, 182 insertions(+), 7 deletions(-) create mode 100644 api/tasks/trigger_subscription_refresh_tasks.py diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 5e97e73e35..803d4404c7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1039,7 +1039,7 @@ class CeleryScheduleTasksConfig(BaseSettings): ) TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field( description="Proactive subscription refresh threshold in seconds", - default=300, + default=60 * 60, ) diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py index f91a849a71..e56ad7da7a 100644 --- a/api/schedule/trigger_provider_refresh_task.py +++ b/api/schedule/trigger_provider_refresh_task.py @@ -1,7 +1,18 @@ import logging +import math import time +from collections.abc import Iterable, Sequence + +from sqlalchemy import ColumnElement, and_, func, or_, select +from sqlalchemy.engine.row import Row +from sqlalchemy.orm import Session import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.trigger import TriggerSubscription +from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh logger = logging.getLogger(__name__) @@ -10,12 +21,78 @@ def _now_ts() -> int: return int(time.time()) -@app.celery.task(queue="trigger_refresh") +def _build_due_filter(now_ts: int): + """Build SQLAlchemy filter for due credential or subscription refresh.""" + credential_due: ColumnElement[bool] = and_( + TriggerSubscription.credential_expires_at != -1, + TriggerSubscription.credential_expires_at + <= now_ts + int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS), + ) + subscription_due: ColumnElement[bool] = and_( + TriggerSubscription.expires_at != -1, + TriggerSubscription.expires_at <= now_ts + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS), + ) + return or_(credential_due, subscription_due) + + +def _lock_keys(rows: Sequence[tuple[str, str]]) -> list[str]: + """Generate redis lock keys for rows as (tenant_id, subscription_id).""" + return [f"trigger_provider_refresh_lock:{tenant_id}_{sid}" for tenant_id, sid in rows] + + +def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]: + """Attempt to acquire locks in a single pipelined round-trip. + + Returns a list of booleans indicating which locks were acquired. + """ + pipe = redis_client.pipeline(transaction=False) + for key in keys: + pipe.set(key, b"1", ex=ttl_seconds, nx=True) + results = pipe.execute() + return [bool(r) for r in results] + + +@app.celery.task(queue="trigger_refresh_publisher") def trigger_provider_refresh() -> None: """ - Simple trigger provider refresh task. - - Scans due trigger subscriptions in small batches - - Refreshes OAuth credentials if needed - - Refreshes subscription metadata if needed + Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks. """ - pass \ No newline at end of file + now = _now_ts() + + batch_size = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE) + lock_ttl = max( + 300, + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS), + ) + + with Session(db.engine, expire_on_commit=False) as session: + filter: ColumnElement[bool] = _build_due_filter(now_ts=now) + total_due: int = session.scalar(statement=select(func.count()).where(filter)) or 0 + if total_due == 0: + return + + pages: int = math.ceil(total_due / batch_size) + for page in range(pages): + offset: int = page * batch_size + subscription_rows: Sequence[Row[tuple[str, str]]] = session.execute( + select(TriggerSubscription.tenant_id, TriggerSubscription.id) + .where(filter) + .order_by(TriggerSubscription.updated_at.asc()) + .offset(offset) + .limit(batch_size) + ).all() + if not subscription_rows: + continue + + subscriptions: list[tuple[str, str]] = [ + (str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows + ] + lock_keys: list[str] = _lock_keys(subscriptions) + acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl) + + for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired): + if not is_locked: + continue + trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id) + + logger.info("Trigger provider refresh queued for due subscriptions: %d", total_due) diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py new file mode 100644 index 0000000000..e0d6d195db --- /dev/null +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -0,0 +1,98 @@ +import logging +import time + +from celery import shared_task +from sqlalchemy.orm import Session + +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.provider import PluginTriggerProviderController +from core.trigger.trigger_manager import TriggerManager +from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_properties +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.provider_ids import TriggerProviderID +from models.trigger import TriggerSubscription +from services.trigger.trigger_provider_service import TriggerProviderService + +logger = logging.getLogger(__name__) + + +def _now_ts() -> int: + return int(time.time()) + + +@shared_task(queue="trigger_refresh_executor") +def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: + """Refresh a trigger subscription if needed, guarded by a Redis in-flight lock.""" + lock_key = f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}" + if not redis_client.get(lock_key): # Lock missing means job already timed out/handled + logger.debug("Refresh lock missing, skip: %s", lock_key) + return + + try: + now: int = _now_ts() + with Session(db.engine) as session: + subscription: TriggerSubscription | None = ( + session.query(TriggerSubscription) + .filter_by(tenant_id=tenant_id, id=subscription_id) + .first() + ) + + if not subscription: + logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id) + return + + # Refresh OAuth token if already expired + if ( + subscription.credential_expires_at != -1 + and int(subscription.credential_expires_at) <= now + and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2 + ): + try: + TriggerProviderService.refresh_oauth_token(tenant_id, subscription.id) + except Exception: + logger.exception("OAuth refresh failed for %s/%s", tenant_id, subscription.id) + # proceed to subscription refresh; provider may still accept late refresh + + # Only refresh subscription when it's actually expired + if subscription.expires_at != -1 and int(subscription.expires_at) <= now: + # Load decrypted subscription and properties + loaded = TriggerProviderService.get_subscription_by_id( + tenant_id=tenant_id, subscription_id=subscription.id + ) + if not loaded: + logger.warning("Subscription vanished during refresh: tenant=%s id=%s", tenant_id, subscription_id) + return + + controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id, TriggerProviderID(loaded.provider_id) + ) + refreshed = controller.refresh_trigger( + subscription=loaded.to_entity(), + credentials=loaded.credentials, + credential_type=CredentialType.of(loaded.credential_type), + ) + + # Persist refreshed properties/expires_at with encryption + properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties( + tenant_id=tenant_id, + controller=controller, + subscription=loaded, + ) + + db_sub: TriggerSubscription | None = ( + session.query(TriggerSubscription) + .filter_by(tenant_id=tenant_id, id=subscription.id) + .first() + ) + if db_sub is not None: + db_sub.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) + db_sub.expires_at = int(refreshed.expires_at) + session.commit() + properties_cache.delete() + finally: + try: + redis_client.delete(lock_key) + except Exception: + # Best-effort lock cleanup + pass From 8ac25c29ee401f24ff3e5b04c15a1f02220ebbc8 Mon Sep 17 00:00:00 2001 From: Harry Date: Tue, 21 Oct 2025 12:09:12 +0800 Subject: [PATCH 2/2] feat(trigger): implement subscription refresh logic with enhanced error handling and logging --- api/core/trigger/utils/locks.py | 12 ++ api/schedule/trigger_provider_refresh_task.py | 34 +++-- .../trigger/trigger_provider_service.py | 88 +++++++++++ .../trigger_subscription_refresh_tasks.py | 143 ++++++++++-------- 4 files changed, 200 insertions(+), 77 deletions(-) create mode 100644 api/core/trigger/utils/locks.py diff --git a/api/core/trigger/utils/locks.py b/api/core/trigger/utils/locks.py new file mode 100644 index 0000000000..46833396e0 --- /dev/null +++ b/api/core/trigger/utils/locks.py @@ -0,0 +1,12 @@ +from collections.abc import Sequence +from itertools import starmap + + +def build_trigger_refresh_lock_key(tenant_id: str, subscription_id: str) -> str: + """Build the Redis lock key for trigger subscription refresh in-flight protection.""" + return f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}" + + +def build_trigger_refresh_lock_keys(pairs: Sequence[tuple[str, str]]) -> list[str]: + """Build Redis lock keys for a sequence of (tenant_id, subscription_id) pairs.""" + return list(starmap(build_trigger_refresh_lock_key, pairs)) diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py index e56ad7da7a..3b3e478793 100644 --- a/api/schedule/trigger_provider_refresh_task.py +++ b/api/schedule/trigger_provider_refresh_task.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session import app from configs import dify_config +from core.trigger.utils.locks import build_trigger_refresh_lock_keys from extensions.ext_database import db from extensions.ext_redis import redis_client from models.trigger import TriggerSubscription @@ -35,11 +36,6 @@ def _build_due_filter(now_ts: int): return or_(credential_due, subscription_due) -def _lock_keys(rows: Sequence[tuple[str, str]]) -> list[str]: - """Generate redis lock keys for rows as (tenant_id, subscription_id).""" - return [f"trigger_provider_refresh_lock:{tenant_id}_{sid}" for tenant_id, sid in rows] - - def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]: """Attempt to acquire locks in a single pipelined round-trip. @@ -57,17 +53,15 @@ def trigger_provider_refresh() -> None: """ Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks. """ - now = _now_ts() + now: int = _now_ts() - batch_size = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE) - lock_ttl = max( - 300, - int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS), - ) + batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE) + lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS)) with Session(db.engine, expire_on_commit=False) as session: filter: ColumnElement[bool] = _build_due_filter(now_ts=now) - total_due: int = session.scalar(statement=select(func.count()).where(filter)) or 0 + total_due: int = int(session.scalar(statement=select(func.count()).where(filter)) or 0) + logger.info("Trigger refresh scan start: due=%d", total_due) if total_due == 0: return @@ -82,17 +76,29 @@ def trigger_provider_refresh() -> None: .limit(batch_size) ).all() if not subscription_rows: + logger.debug("Trigger refresh page %d/%d empty", page + 1, pages) continue subscriptions: list[tuple[str, str]] = [ (str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows ] - lock_keys: list[str] = _lock_keys(subscriptions) + lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions) acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl) + enqueued: int = 0 for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired): if not is_locked: continue trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id) + enqueued += 1 - logger.info("Trigger provider refresh queued for due subscriptions: %d", total_due) + logger.info( + "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d", + page + 1, + pages, + len(subscriptions), + sum(1 for x in acquired if x), + enqueued, + ) + + logger.info("Trigger refresh scan done: due=%d", total_due) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 4cc19b485a..1c474b2c48 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -1,5 +1,6 @@ import json import logging +import time as _time import uuid from collections.abc import Mapping from typing import Any, Optional @@ -18,6 +19,7 @@ from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, ) +from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import ( @@ -25,6 +27,7 @@ from core.trigger.utils.encryption import ( create_trigger_provider_encrypter_for_subscription, delete_cache_for_subscription, ) +from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import TriggerProviderID @@ -347,6 +350,91 @@ class TriggerProviderService: "expires_at": refreshed_credentials.expires_at, } + @classmethod + def refresh_subscription( + cls, + tenant_id: str, + subscription_id: str, + now: int | None = None, + ) -> Mapping[str, Any]: + """ + Refresh trigger subscription if expired. + + Args: + tenant_id: Tenant ID + subscription_id: Subscription instance ID + now: Current timestamp, defaults to `int(time.time())` + + Returns: + Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value) + """ + now_ts: int = int(now if now is not None else _time.time()) + + with Session(db.engine) as session: + subscription: TriggerSubscription | None = ( + session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + ) + if subscription is None: + raise ValueError(f"Trigger provider subscription {subscription_id} not found") + + if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts: + logger.debug( + "Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s", + tenant_id, + subscription_id, + subscription.expires_at, + now_ts, + ) + return {"result": "skipped", "expires_at": int(subscription.expires_at)} + + provider_id = TriggerProviderID(subscription.provider_id) + controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) + + # Decrypt credentials and properties for runtime + credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription( + tenant_id=tenant_id, + controller=controller, + subscription=subscription, + ) + properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties( + tenant_id=tenant_id, + controller=controller, + subscription=subscription, + ) + + decrypted_credentials = credential_encrypter.decrypt(subscription.credentials) + decrypted_properties = properties_encrypter.decrypt(subscription.properties) + + sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity( + expires_at=int(subscription.expires_at), + endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), + parameters=subscription.parameters, + properties=decrypted_properties, + ) + + refreshed: TriggerSubscriptionEntity = controller.refresh_trigger( + subscription=sub_entity, + credentials=decrypted_credentials, + credential_type=CredentialType.of(subscription.credential_type), + ) + + # Persist refreshed properties and expires_at + subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) + subscription.expires_at = int(refreshed.expires_at) + session.commit() + properties_cache.delete() + + logger.info( + "Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s", + tenant_id, + subscription_id, + subscription.expires_at, + ) + + return {"result": "success", "expires_at": int(refreshed.expires_at)} + @classmethod def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]: """ diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py index e0d6d195db..11324df881 100644 --- a/api/tasks/trigger_subscription_refresh_tasks.py +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -1,16 +1,15 @@ import logging import time +from collections.abc import Mapping +from typing import Any from celery import shared_task from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from core.trigger.provider import PluginTriggerProviderController -from core.trigger.trigger_manager import TriggerManager -from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_properties +from core.trigger.utils.locks import build_trigger_refresh_lock_key from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService @@ -21,78 +20,96 @@ def _now_ts() -> int: return int(time.time()) -@shared_task(queue="trigger_refresh_executor") -def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: - """Refresh a trigger subscription if needed, guarded by a Redis in-flight lock.""" - lock_key = f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}" - if not redis_client.get(lock_key): # Lock missing means job already timed out/handled - logger.debug("Refresh lock missing, skip: %s", lock_key) +def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None: + return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + + +def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None: + if ( + subscription.credential_expires_at != -1 + and int(subscription.credential_expires_at) <= now + and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2 + ): + logger.info( + "Refreshing OAuth token: tenant=%s subscription_id=%s expires_at=%s now=%s", + tenant_id, + subscription.id, + subscription.credential_expires_at, + now, + ) + try: + result: Mapping[str, Any] = TriggerProviderService.refresh_oauth_token( + tenant_id=tenant_id, subscription_id=subscription.id + ) + logger.info( + "OAuth token refreshed: tenant=%s subscription_id=%s result=%s", tenant_id, subscription.id, result + ) + except Exception: + logger.exception("OAuth refresh failed: tenant=%s subscription_id=%s", tenant_id, subscription.id) + + +def _refresh_subscription_if_expired( + tenant_id: str, + subscription: TriggerSubscription, + now: int, +) -> None: + if subscription.expires_at == -1 or int(subscription.expires_at) > now: + logger.debug( + "Subscription not due: tenant=%s subscription_id=%s expires_at=%s now=%s", + tenant_id, + subscription.id, + subscription.expires_at, + now, + ) return + try: + result: Mapping[str, Any] = TriggerProviderService.refresh_subscription( + tenant_id=tenant_id, subscription_id=subscription.id, now=now + ) + logger.info( + "Subscription refreshed: tenant=%s subscription_id=%s result=%s", + tenant_id, + subscription.id, + result.get("result"), + ) + except Exception: + logger.exception("Subscription refresh failed: tenant=%s id=%s", tenant_id, subscription.id) + + +@shared_task(queue="trigger_refresh_executor") +def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: + """Refresh a trigger subscription if needed, guarded by a Redis in-flight lock.""" + lock_key: str = build_trigger_refresh_lock_key(tenant_id, subscription_id) + if not redis_client.get(lock_key): + logger.debug("Refresh lock missing, skip: %s", lock_key) + return + + logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id) try: now: int = _now_ts() with Session(db.engine) as session: - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, id=subscription_id) - .first() - ) + subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id) if not subscription: logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id) return - # Refresh OAuth token if already expired - if ( - subscription.credential_expires_at != -1 - and int(subscription.credential_expires_at) <= now - and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2 - ): - try: - TriggerProviderService.refresh_oauth_token(tenant_id, subscription.id) - except Exception: - logger.exception("OAuth refresh failed for %s/%s", tenant_id, subscription.id) - # proceed to subscription refresh; provider may still accept late refresh + logger.debug( + "Loaded subscription: tenant=%s id=%s cred_exp=%s sub_exp=%s now=%s", + tenant_id, + subscription.id, + subscription.credential_expires_at, + subscription.expires_at, + now, + ) - # Only refresh subscription when it's actually expired - if subscription.expires_at != -1 and int(subscription.expires_at) <= now: - # Load decrypted subscription and properties - loaded = TriggerProviderService.get_subscription_by_id( - tenant_id=tenant_id, subscription_id=subscription.id - ) - if not loaded: - logger.warning("Subscription vanished during refresh: tenant=%s id=%s", tenant_id, subscription_id) - return - - controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( - tenant_id, TriggerProviderID(loaded.provider_id) - ) - refreshed = controller.refresh_trigger( - subscription=loaded.to_entity(), - credentials=loaded.credentials, - credential_type=CredentialType.of(loaded.credential_type), - ) - - # Persist refreshed properties/expires_at with encryption - properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties( - tenant_id=tenant_id, - controller=controller, - subscription=loaded, - ) - - db_sub: TriggerSubscription | None = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, id=subscription.id) - .first() - ) - if db_sub is not None: - db_sub.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) - db_sub.expires_at = int(refreshed.expires_at) - session.commit() - properties_cache.delete() + _refresh_oauth_if_expired(tenant_id=tenant_id, subscription=subscription, now=now) + _refresh_subscription_if_expired(tenant_id=tenant_id, subscription=subscription, now=now) finally: try: redis_client.delete(lock_key) + logger.debug("Lock released: %s", lock_key) except Exception: # Best-effort lock cleanup - pass + logger.warning("Failed to release lock: %s", lock_key, exc_info=True)