mirror of https://github.com/langgenius/dify.git
feat(trigger): implement subscription refresh logic with enhanced error handling and logging
This commit is contained in:
parent
f4517d667b
commit
8ac25c29ee
|
|
@ -0,0 +1,12 @@
|
|||
from collections.abc import Sequence
|
||||
from itertools import starmap
|
||||
|
||||
|
||||
def build_trigger_refresh_lock_key(tenant_id: str, subscription_id: str) -> str:
|
||||
"""Build the Redis lock key for trigger subscription refresh in-flight protection."""
|
||||
return f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}"
|
||||
|
||||
|
||||
def build_trigger_refresh_lock_keys(pairs: Sequence[tuple[str, str]]) -> list[str]:
|
||||
"""Build Redis lock keys for a sequence of (tenant_id, subscription_id) pairs."""
|
||||
return list(starmap(build_trigger_refresh_lock_key, pairs))
|
||||
|
|
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from core.trigger.utils.locks import build_trigger_refresh_lock_keys
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerSubscription
|
||||
|
|
@ -35,11 +36,6 @@ def _build_due_filter(now_ts: int):
|
|||
return or_(credential_due, subscription_due)
|
||||
|
||||
|
||||
def _lock_keys(rows: Sequence[tuple[str, str]]) -> list[str]:
|
||||
"""Generate redis lock keys for rows as (tenant_id, subscription_id)."""
|
||||
return [f"trigger_provider_refresh_lock:{tenant_id}_{sid}" for tenant_id, sid in rows]
|
||||
|
||||
|
||||
def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]:
|
||||
"""Attempt to acquire locks in a single pipelined round-trip.
|
||||
|
||||
|
|
@ -57,17 +53,15 @@ def trigger_provider_refresh() -> None:
|
|||
"""
|
||||
Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks.
|
||||
"""
|
||||
now = _now_ts()
|
||||
now: int = _now_ts()
|
||||
|
||||
batch_size = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE)
|
||||
lock_ttl = max(
|
||||
300,
|
||||
int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS),
|
||||
)
|
||||
batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE)
|
||||
lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
filter: ColumnElement[bool] = _build_due_filter(now_ts=now)
|
||||
total_due: int = session.scalar(statement=select(func.count()).where(filter)) or 0
|
||||
total_due: int = int(session.scalar(statement=select(func.count()).where(filter)) or 0)
|
||||
logger.info("Trigger refresh scan start: due=%d", total_due)
|
||||
if total_due == 0:
|
||||
return
|
||||
|
||||
|
|
@ -82,17 +76,29 @@ def trigger_provider_refresh() -> None:
|
|||
.limit(batch_size)
|
||||
).all()
|
||||
if not subscription_rows:
|
||||
logger.debug("Trigger refresh page %d/%d empty", page + 1, pages)
|
||||
continue
|
||||
|
||||
subscriptions: list[tuple[str, str]] = [
|
||||
(str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows
|
||||
]
|
||||
lock_keys: list[str] = _lock_keys(subscriptions)
|
||||
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
||||
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
||||
|
||||
enqueued: int = 0
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
|
||||
if not is_locked:
|
||||
continue
|
||||
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
enqueued += 1
|
||||
|
||||
logger.info("Trigger provider refresh queued for due subscriptions: %d", total_due)
|
||||
logger.info(
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
|
||||
page + 1,
|
||||
pages,
|
||||
len(subscriptions),
|
||||
sum(1 for x in acquired if x),
|
||||
enqueued,
|
||||
)
|
||||
|
||||
logger.info("Trigger refresh scan done: due=%d", total_due)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import time as _time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
|
@ -18,6 +19,7 @@ from core.trigger.entities.api_entities import (
|
|||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
)
|
||||
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import (
|
||||
|
|
@ -25,6 +27,7 @@ from core.trigger.utils.encryption import (
|
|||
create_trigger_provider_encrypter_for_subscription,
|
||||
delete_cache_for_subscription,
|
||||
)
|
||||
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
|
@ -347,6 +350,91 @@ class TriggerProviderService:
|
|||
"expires_at": refreshed_credentials.expires_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def refresh_subscription(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
now: int | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Refresh trigger subscription if expired.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
subscription_id: Subscription instance ID
|
||||
now: Current timestamp, defaults to `int(time.time())`
|
||||
|
||||
Returns:
|
||||
Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value)
|
||||
"""
|
||||
now_ts: int = int(now if now is not None else _time.time())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
if subscription is None:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts:
|
||||
logger.debug(
|
||||
"Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s",
|
||||
tenant_id,
|
||||
subscription_id,
|
||||
subscription.expires_at,
|
||||
now_ts,
|
||||
)
|
||||
return {"result": "skipped", "expires_at": int(subscription.expires_at)}
|
||||
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
|
||||
# Decrypt credentials and properties for runtime
|
||||
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
|
||||
tenant_id=tenant_id,
|
||||
controller=controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
|
||||
decrypted_credentials = credential_encrypter.decrypt(subscription.credentials)
|
||||
decrypted_properties = properties_encrypter.decrypt(subscription.properties)
|
||||
|
||||
sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity(
|
||||
expires_at=int(subscription.expires_at),
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=subscription.parameters,
|
||||
properties=decrypted_properties,
|
||||
)
|
||||
|
||||
refreshed: TriggerSubscriptionEntity = controller.refresh_trigger(
|
||||
subscription=sub_entity,
|
||||
credentials=decrypted_credentials,
|
||||
credential_type=CredentialType.of(subscription.credential_type),
|
||||
)
|
||||
|
||||
# Persist refreshed properties and expires_at
|
||||
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
|
||||
subscription.expires_at = int(refreshed.expires_at)
|
||||
session.commit()
|
||||
properties_cache.delete()
|
||||
|
||||
logger.info(
|
||||
"Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s",
|
||||
tenant_id,
|
||||
subscription_id,
|
||||
subscription.expires_at,
|
||||
)
|
||||
|
||||
return {"result": "success", "expires_at": int(refreshed.expires_at)}
|
||||
|
||||
@classmethod
|
||||
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_properties
|
||||
from core.trigger.utils.locks import build_trigger_refresh_lock_key
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
|
|
@ -21,78 +20,96 @@ def _now_ts() -> int:
|
|||
return int(time.time())
|
||||
|
||||
|
||||
@shared_task(queue="trigger_refresh_executor")
|
||||
def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
|
||||
"""Refresh a trigger subscription if needed, guarded by a Redis in-flight lock."""
|
||||
lock_key = f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}"
|
||||
if not redis_client.get(lock_key): # Lock missing means job already timed out/handled
|
||||
logger.debug("Refresh lock missing, skip: %s", lock_key)
|
||||
def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None:
|
||||
return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
|
||||
|
||||
def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None:
|
||||
if (
|
||||
subscription.credential_expires_at != -1
|
||||
and int(subscription.credential_expires_at) <= now
|
||||
and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2
|
||||
):
|
||||
logger.info(
|
||||
"Refreshing OAuth token: tenant=%s subscription_id=%s expires_at=%s now=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
subscription.credential_expires_at,
|
||||
now,
|
||||
)
|
||||
try:
|
||||
result: Mapping[str, Any] = TriggerProviderService.refresh_oauth_token(
|
||||
tenant_id=tenant_id, subscription_id=subscription.id
|
||||
)
|
||||
logger.info(
|
||||
"OAuth token refreshed: tenant=%s subscription_id=%s result=%s", tenant_id, subscription.id, result
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("OAuth refresh failed: tenant=%s subscription_id=%s", tenant_id, subscription.id)
|
||||
|
||||
|
||||
def _refresh_subscription_if_expired(
|
||||
tenant_id: str,
|
||||
subscription: TriggerSubscription,
|
||||
now: int,
|
||||
) -> None:
|
||||
if subscription.expires_at == -1 or int(subscription.expires_at) > now:
|
||||
logger.debug(
|
||||
"Subscription not due: tenant=%s subscription_id=%s expires_at=%s now=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
subscription.expires_at,
|
||||
now,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
result: Mapping[str, Any] = TriggerProviderService.refresh_subscription(
|
||||
tenant_id=tenant_id, subscription_id=subscription.id, now=now
|
||||
)
|
||||
logger.info(
|
||||
"Subscription refreshed: tenant=%s subscription_id=%s result=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
result.get("result"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Subscription refresh failed: tenant=%s id=%s", tenant_id, subscription.id)
|
||||
|
||||
|
||||
@shared_task(queue="trigger_refresh_executor")
|
||||
def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
|
||||
"""Refresh a trigger subscription if needed, guarded by a Redis in-flight lock."""
|
||||
lock_key: str = build_trigger_refresh_lock_key(tenant_id, subscription_id)
|
||||
if not redis_client.get(lock_key):
|
||||
logger.debug("Refresh lock missing, skip: %s", lock_key)
|
||||
return
|
||||
|
||||
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
|
||||
try:
|
||||
now: int = _now_ts()
|
||||
with Session(db.engine) as session:
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, id=subscription_id)
|
||||
.first()
|
||||
)
|
||||
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
|
||||
|
||||
if not subscription:
|
||||
logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id)
|
||||
return
|
||||
|
||||
# Refresh OAuth token if already expired
|
||||
if (
|
||||
subscription.credential_expires_at != -1
|
||||
and int(subscription.credential_expires_at) <= now
|
||||
and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2
|
||||
):
|
||||
try:
|
||||
TriggerProviderService.refresh_oauth_token(tenant_id, subscription.id)
|
||||
except Exception:
|
||||
logger.exception("OAuth refresh failed for %s/%s", tenant_id, subscription.id)
|
||||
# proceed to subscription refresh; provider may still accept late refresh
|
||||
logger.debug(
|
||||
"Loaded subscription: tenant=%s id=%s cred_exp=%s sub_exp=%s now=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
subscription.credential_expires_at,
|
||||
subscription.expires_at,
|
||||
now,
|
||||
)
|
||||
|
||||
# Only refresh subscription when it's actually expired
|
||||
if subscription.expires_at != -1 and int(subscription.expires_at) <= now:
|
||||
# Load decrypted subscription and properties
|
||||
loaded = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=tenant_id, subscription_id=subscription.id
|
||||
)
|
||||
if not loaded:
|
||||
logger.warning("Subscription vanished during refresh: tenant=%s id=%s", tenant_id, subscription_id)
|
||||
return
|
||||
|
||||
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(loaded.provider_id)
|
||||
)
|
||||
refreshed = controller.refresh_trigger(
|
||||
subscription=loaded.to_entity(),
|
||||
credentials=loaded.credentials,
|
||||
credential_type=CredentialType.of(loaded.credential_type),
|
||||
)
|
||||
|
||||
# Persist refreshed properties/expires_at with encryption
|
||||
properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
|
||||
tenant_id=tenant_id,
|
||||
controller=controller,
|
||||
subscription=loaded,
|
||||
)
|
||||
|
||||
db_sub: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, id=subscription.id)
|
||||
.first()
|
||||
)
|
||||
if db_sub is not None:
|
||||
db_sub.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
|
||||
db_sub.expires_at = int(refreshed.expires_at)
|
||||
session.commit()
|
||||
properties_cache.delete()
|
||||
_refresh_oauth_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
|
||||
_refresh_subscription_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
|
||||
finally:
|
||||
try:
|
||||
redis_client.delete(lock_key)
|
||||
logger.debug("Lock released: %s", lock_key)
|
||||
except Exception:
|
||||
# Best-effort lock cleanup
|
||||
pass
|
||||
logger.warning("Failed to release lock: %s", lock_key, exc_info=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue