This commit is contained in:
Joel 2025-10-21 14:44:24 +08:00
commit 25e4203cb1
5 changed files with 305 additions and 7 deletions

View File

@ -1039,7 +1039,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
)
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds",
default=300,
default=60 * 60,
)

View File

@ -0,0 +1,12 @@
from collections.abc import Sequence
from itertools import starmap
def build_trigger_refresh_lock_key(tenant_id: str, subscription_id: str) -> str:
"""Build the Redis lock key for trigger subscription refresh in-flight protection."""
return f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}"
def build_trigger_refresh_lock_keys(pairs: Sequence[tuple[str, str]]) -> list[str]:
"""Build Redis lock keys for a sequence of (tenant_id, subscription_id) pairs."""
return list(starmap(build_trigger_refresh_lock_key, pairs))

View File

@ -1,7 +1,19 @@
import logging
import math
import time
from collections.abc import Iterable, Sequence
from sqlalchemy import ColumnElement, and_, func, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
import app
from configs import dify_config
from core.trigger.utils.locks import build_trigger_refresh_lock_keys
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh
logger = logging.getLogger(__name__)
@ -10,12 +22,83 @@ def _now_ts() -> int:
return int(time.time())
@app.celery.task(queue="trigger_refresh")
def _build_due_filter(now_ts: int):
"""Build SQLAlchemy filter for due credential or subscription refresh."""
credential_due: ColumnElement[bool] = and_(
TriggerSubscription.credential_expires_at != -1,
TriggerSubscription.credential_expires_at
<= now_ts + int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS),
)
subscription_due: ColumnElement[bool] = and_(
TriggerSubscription.expires_at != -1,
TriggerSubscription.expires_at <= now_ts + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS),
)
return or_(credential_due, subscription_due)
def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]:
"""Attempt to acquire locks in a single pipelined round-trip.
Returns a list of booleans indicating which locks were acquired.
"""
pipe = redis_client.pipeline(transaction=False)
for key in keys:
pipe.set(key, b"1", ex=ttl_seconds, nx=True)
results = pipe.execute()
return [bool(r) for r in results]
@app.celery.task(queue="trigger_refresh_publisher")
def trigger_provider_refresh() -> None:
"""
Simple trigger provider refresh task.
- Scans due trigger subscriptions in small batches
- Refreshes OAuth credentials if needed
- Refreshes subscription metadata if needed
Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks.
"""
pass
now: int = _now_ts()
batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE)
lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS))
with Session(db.engine, expire_on_commit=False) as session:
filter: ColumnElement[bool] = _build_due_filter(now_ts=now)
total_due: int = int(session.scalar(statement=select(func.count()).where(filter)) or 0)
logger.info("Trigger refresh scan start: due=%d", total_due)
if total_due == 0:
return
pages: int = math.ceil(total_due / batch_size)
for page in range(pages):
offset: int = page * batch_size
subscription_rows: Sequence[Row[tuple[str, str]]] = session.execute(
select(TriggerSubscription.tenant_id, TriggerSubscription.id)
.where(filter)
.order_by(TriggerSubscription.updated_at.asc())
.offset(offset)
.limit(batch_size)
).all()
if not subscription_rows:
logger.debug("Trigger refresh page %d/%d empty", page + 1, pages)
continue
subscriptions: list[tuple[str, str]] = [
(str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows
]
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
enqueued: int = 0
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
if not is_locked:
continue
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
enqueued += 1
logger.info(
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
page + 1,
pages,
len(subscriptions),
sum(1 for x in acquired if x),
enqueued,
)
logger.info("Trigger refresh scan done: due=%d", total_due)

View File

@ -1,5 +1,6 @@
import json
import logging
import time as _time
import uuid
from collections.abc import Mapping
from typing import Any, Optional
@ -18,6 +19,7 @@ from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
)
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
@ -25,6 +27,7 @@ from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_subscription,
delete_cache_for_subscription,
)
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import TriggerProviderID
@ -347,6 +350,91 @@ class TriggerProviderService:
"expires_at": refreshed_credentials.expires_at,
}
@classmethod
def refresh_subscription(
cls,
tenant_id: str,
subscription_id: str,
now: int | None = None,
) -> Mapping[str, Any]:
"""
Refresh trigger subscription if expired.
Args:
tenant_id: Tenant ID
subscription_id: Subscription instance ID
now: Current timestamp, defaults to `int(time.time())`
Returns:
Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value)
"""
now_ts: int = int(now if now is not None else _time.time())
with Session(db.engine) as session:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if subscription is None:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts:
logger.debug(
"Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s",
tenant_id,
subscription_id,
subscription.expires_at,
now_ts,
)
return {"result": "skipped", "expires_at": int(subscription.expires_at)}
provider_id = TriggerProviderID(subscription.provider_id)
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
# Decrypt credentials and properties for runtime
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=controller,
subscription=subscription,
)
properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
tenant_id=tenant_id,
controller=controller,
subscription=subscription,
)
decrypted_credentials = credential_encrypter.decrypt(subscription.credentials)
decrypted_properties = properties_encrypter.decrypt(subscription.properties)
sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity(
expires_at=int(subscription.expires_at),
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=subscription.parameters,
properties=decrypted_properties,
)
refreshed: TriggerSubscriptionEntity = controller.refresh_trigger(
subscription=sub_entity,
credentials=decrypted_credentials,
credential_type=CredentialType.of(subscription.credential_type),
)
# Persist refreshed properties and expires_at
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
subscription.expires_at = int(refreshed.expires_at)
session.commit()
properties_cache.delete()
logger.info(
"Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s",
tenant_id,
subscription_id,
subscription.expires_at,
)
return {"result": "success", "expires_at": int(refreshed.expires_at)}
@classmethod
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
"""

View File

@ -0,0 +1,115 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
from celery import shared_task
from sqlalchemy.orm import Session
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
def _now_ts() -> int:
return int(time.time())
def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None:
return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None:
if (
subscription.credential_expires_at != -1
and int(subscription.credential_expires_at) <= now
and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2
):
logger.info(
"Refreshing OAuth token: tenant=%s subscription_id=%s expires_at=%s now=%s",
tenant_id,
subscription.id,
subscription.credential_expires_at,
now,
)
try:
result: Mapping[str, Any] = TriggerProviderService.refresh_oauth_token(
tenant_id=tenant_id, subscription_id=subscription.id
)
logger.info(
"OAuth token refreshed: tenant=%s subscription_id=%s result=%s", tenant_id, subscription.id, result
)
except Exception:
logger.exception("OAuth refresh failed: tenant=%s subscription_id=%s", tenant_id, subscription.id)
def _refresh_subscription_if_expired(
tenant_id: str,
subscription: TriggerSubscription,
now: int,
) -> None:
if subscription.expires_at == -1 or int(subscription.expires_at) > now:
logger.debug(
"Subscription not due: tenant=%s subscription_id=%s expires_at=%s now=%s",
tenant_id,
subscription.id,
subscription.expires_at,
now,
)
return
try:
result: Mapping[str, Any] = TriggerProviderService.refresh_subscription(
tenant_id=tenant_id, subscription_id=subscription.id, now=now
)
logger.info(
"Subscription refreshed: tenant=%s subscription_id=%s result=%s",
tenant_id,
subscription.id,
result.get("result"),
)
except Exception:
logger.exception("Subscription refresh failed: tenant=%s id=%s", tenant_id, subscription.id)
@shared_task(queue="trigger_refresh_executor")
def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
"""Refresh a trigger subscription if needed, guarded by a Redis in-flight lock."""
lock_key: str = build_trigger_refresh_lock_key(tenant_id, subscription_id)
if not redis_client.get(lock_key):
logger.debug("Refresh lock missing, skip: %s", lock_key)
return
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
with Session(db.engine) as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:
logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id)
return
logger.debug(
"Loaded subscription: tenant=%s id=%s cred_exp=%s sub_exp=%s now=%s",
tenant_id,
subscription.id,
subscription.credential_expires_at,
subscription.expires_at,
now,
)
_refresh_oauth_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
_refresh_subscription_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
finally:
try:
redis_client.delete(lock_key)
logger.debug("Lock released: %s", lock_key)
except Exception:
# Best-effort lock cleanup
logger.warning("Failed to release lock: %s", lock_key, exc_info=True)