mirror of https://github.com/langgenius/dify.git
feat(trigger): implement atomic update and verification for subscription builders
- Introduced atomic operations for updating and verifying subscription builders to prevent race conditions. - Added distributed locking mechanism to ensure data consistency during concurrent updates and builds. - Refactored existing methods to utilize the new atomic update and verification logic, enhancing the reliability of trigger subscription handling.
This commit is contained in:
parent
beff639c3d
commit
cca48f07aa
|
|
@ -133,20 +133,16 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
# Use atomic update_and_verify to prevent race conditions
|
||||
return TriggerSubscriptionBuilderService.update_and_verify_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise ValueError(str(e)) from e
|
||||
|
|
@ -232,8 +228,10 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
# Use atomic update_and_build to prevent race conditions
|
||||
TriggerSubscriptionBuilderService.update_and_build_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
|
|
@ -242,12 +240,6 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
properties=args.get("properties", None),
|
||||
),
|
||||
)
|
||||
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
return 200
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -44,10 +45,31 @@ class TriggerSubscriptionBuilderService:
|
|||
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||
__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__ = 30 * 60
|
||||
|
||||
##########################
|
||||
# Distributed lock
|
||||
##########################
|
||||
__LOCK_EXPIRE_SECONDS__ = 30
|
||||
|
||||
@classmethod
|
||||
def encode_cache_key(cls, subscription_id: str) -> str:
|
||||
return f"trigger:subscription:builder:{subscription_id}"
|
||||
|
||||
@classmethod
|
||||
def encode_lock_key(cls, subscription_id: str) -> str:
|
||||
return f"trigger:subscription:builder:lock:{subscription_id}"
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def acquire_builder_lock(cls, subscription_id: str):
|
||||
"""
|
||||
Acquire a distributed lock for a subscription builder.
|
||||
|
||||
:param subscription_id: The subscription builder ID
|
||||
"""
|
||||
lock_key = cls.encode_lock_key(subscription_id)
|
||||
with redis_client.lock(lock_key, timeout=cls.__LOCK_EXPIRE_SECONDS__):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
def verify_trigger_subscription_builder(
|
||||
cls,
|
||||
|
|
@ -87,58 +109,64 @@ class TriggerSubscriptionBuilderService:
|
|||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
# Acquire lock to prevent concurrent build operations
|
||||
with cls.acquire_builder_lock(subscription_builder_id):
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
if not subscription_builder.name:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
if not subscription_builder.name:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
subscription_id=subscription_builder.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription_builder.properties,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# automatically create
|
||||
subscription: Subscription = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
|
||||
parameters=subscription_builder.parameters,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
credential_type = CredentialType.of(
|
||||
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
|
||||
)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
subscription_id=subscription_builder.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription_builder.properties,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# automatically create
|
||||
subscription: Subscription = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
|
||||
parameters=subscription_builder.parameters,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
subscription_id=subscription_builder.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription.properties,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
)
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
subscription_id=subscription_builder.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription.properties,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
)
|
||||
|
||||
cls.delete_trigger_subscription_builder(subscription_builder_id)
|
||||
# Delete the builder after successful subscription creation
|
||||
cache_key = cls.encode_cache_key(subscription_builder_id)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def create_trigger_subscription_builder(
|
||||
|
|
@ -191,17 +219,154 @@ class TriggerSubscriptionBuilderService:
|
|||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||
# Acquire lock to prevent concurrent updates
|
||||
with cls.acquire_builder_lock(subscription_id):
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||
|
||||
subscription_builder_updater.update(subscription_builder_cache)
|
||||
subscription_builder_updater.update(subscription_builder_cache)
|
||||
|
||||
redis_client.setex(
|
||||
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
|
||||
)
|
||||
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)
|
||||
redis_client.setex(
|
||||
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
|
||||
)
|
||||
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)
|
||||
|
||||
@classmethod
|
||||
def update_and_verify_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
subscription_builder_updater: SubscriptionBuilderUpdater,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Atomically update and verify a subscription builder.
|
||||
This ensures the verification is done on the exact data that was just updated.
|
||||
"""
|
||||
subscription_id = subscription_builder_id
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
# Acquire lock for the entire update + verify operation
|
||||
with cls.acquire_builder_lock(subscription_id):
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||
|
||||
# Update
|
||||
subscription_builder_updater.update(subscription_builder_cache)
|
||||
redis_client.setex(
|
||||
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
|
||||
)
|
||||
|
||||
# Verify (using the just-updated data)
|
||||
if subscription_builder_cache.credential_type == CredentialType.OAUTH2:
|
||||
return {"verified": bool(subscription_builder_cache.credentials)}
|
||||
|
||||
if subscription_builder_cache.credential_type == CredentialType.API_KEY:
|
||||
credentials_to_validate = subscription_builder_cache.credentials
|
||||
try:
|
||||
provider_controller.validate_credentials(user_id, credentials_to_validate)
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
raise ValueError(f"Invalid credentials: {e}")
|
||||
return {"verified": True}
|
||||
|
||||
return {"verified": True}
|
||||
|
||||
@classmethod
|
||||
def update_and_build_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
subscription_builder_updater: SubscriptionBuilderUpdater,
|
||||
) -> None:
|
||||
"""
|
||||
Atomically update and build a subscription builder.
|
||||
This ensures the build uses the exact data that was just updated.
|
||||
"""
|
||||
subscription_id = subscription_builder_id
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
# Acquire lock for the entire update + build operation
|
||||
with cls.acquire_builder_lock(subscription_id):
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||
|
||||
# Update
|
||||
subscription_builder_updater.update(subscription_builder_cache)
|
||||
redis_client.setex(
|
||||
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
|
||||
)
|
||||
|
||||
# Re-fetch to ensure we have the latest data
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
if not subscription_builder.name:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
# Build
|
||||
credential_type = CredentialType.of(
|
||||
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
|
||||
)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
subscription_id=subscription_builder.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription_builder.properties,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# automatically create
|
||||
subscription: Subscription = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
|
||||
parameters=subscription_builder.parameters,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
subscription_id=subscription_builder.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription.properties,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
)
|
||||
|
||||
# Delete the builder after successful subscription creation
|
||||
cache_key = cls.encode_cache_key(subscription_builder_id)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def builder_to_api_entity(
|
||||
|
|
@ -222,14 +387,6 @@ class TriggerSubscriptionBuilderService:
|
|||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
|
||||
"""
|
||||
Delete a trigger subscription validation.
|
||||
"""
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue