diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 888bd77cf5..a69c6a48bb 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -133,20 +133,16 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): args = parser.parse_args() try: - TriggerSubscriptionBuilderService.update_trigger_subscription_builder( + # Use atomic update_and_verify to prevent race conditions + return TriggerSubscriptionBuilderService.update_and_verify_builder( tenant_id=user.current_tenant_id, + user_id=user.id, provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( credentials=args.get("credentials", None), ), ) - return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder( - tenant_id=user.current_tenant_id, - user_id=user.id, - provider_id=TriggerProviderID(provider), - subscription_builder_id=subscription_builder_id, - ) except Exception as e: logger.exception("Error verifying provider credential", exc_info=e) raise ValueError(str(e)) from e @@ -232,8 +228,10 @@ class TriggerSubscriptionBuilderBuildApi(Resource): parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() try: - TriggerSubscriptionBuilderService.update_trigger_subscription_builder( + # Use atomic update_and_build to prevent race conditions + TriggerSubscriptionBuilderService.update_and_build_builder( tenant_id=user.current_tenant_id, + user_id=user.id, provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( @@ -242,12 +240,6 @@ class TriggerSubscriptionBuilderBuildApi(Resource): properties=args.get("properties", None), ), ) - TriggerSubscriptionBuilderService.build_trigger_subscription_builder( - tenant_id=user.current_tenant_id, - user_id=user.id, - provider_id=TriggerProviderID(provider), - subscription_builder_id=subscription_builder_id, - ) return 200 except Exception as e: logger.exception("Error building provider credential", exc_info=e) diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 03bc4295f2..6e4a3208cf 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -2,6 +2,7 @@ import json import logging import uuid from collections.abc import Mapping +from contextlib import contextmanager from datetime import datetime from typing import Any @@ -44,10 +45,31 @@ class TriggerSubscriptionBuilderService: __VALIDATION_REQUEST_CACHE_COUNT__ = 10 __VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__ = 30 * 60 + ########################## + # Distributed lock + ########################## + __LOCK_EXPIRE_SECONDS__ = 30 + @classmethod def encode_cache_key(cls, subscription_id: str) -> str: return f"trigger:subscription:builder:{subscription_id}" + @classmethod + def encode_lock_key(cls, subscription_id: str) -> str: + return f"trigger:subscription:builder:lock:{subscription_id}" + + @classmethod + @contextmanager + def acquire_builder_lock(cls, subscription_id: str): + """ + Acquire a distributed lock for a subscription builder. + + :param subscription_id: The subscription builder ID + """ + lock_key = cls.encode_lock_key(subscription_id) + with redis_client.lock(lock_key, timeout=cls.__LOCK_EXPIRE_SECONDS__): + yield + @classmethod def verify_trigger_subscription_builder( cls, @@ -87,58 +109,64 @@ class TriggerSubscriptionBuilderService: if not provider_controller: raise ValueError(f"Provider {provider_id} not found") - subscription_builder = cls.get_subscription_builder(subscription_builder_id) - if not subscription_builder: - raise ValueError(f"Subscription builder {subscription_builder_id} not found") + # Acquire lock to prevent concurrent build operations + with cls.acquire_builder_lock(subscription_builder_id): + subscription_builder = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder: + raise ValueError(f"Subscription builder {subscription_builder_id} not found") - if not subscription_builder.name: - raise ValueError("Subscription builder name is required") + if not subscription_builder.name: + raise ValueError("Subscription builder name is required") - credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value) - if credential_type == CredentialType.UNAUTHORIZED: - # manually create - TriggerProviderService.add_trigger_subscription( - subscription_id=subscription_builder.id, - tenant_id=tenant_id, - user_id=user_id, - name=subscription_builder.name, - provider_id=provider_id, - endpoint_id=subscription_builder.endpoint_id, - parameters=subscription_builder.parameters, - properties=subscription_builder.properties, - credential_expires_at=subscription_builder.credential_expires_at or -1, - expires_at=subscription_builder.expires_at, - credentials=subscription_builder.credentials, - credential_type=credential_type, - ) - else: - # automatically create - subscription: Subscription = TriggerManager.subscribe_trigger( - tenant_id=tenant_id, - user_id=user_id, - provider_id=provider_id, - endpoint=parse_endpoint_id(subscription_builder.endpoint_id), - parameters=subscription_builder.parameters, - credentials=subscription_builder.credentials, - credential_type=credential_type, + credential_type = CredentialType.of( + subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value ) + if credential_type == CredentialType.UNAUTHORIZED: + # manually create + TriggerProviderService.add_trigger_subscription( + subscription_id=subscription_builder.id, + tenant_id=tenant_id, + user_id=user_id, + name=subscription_builder.name, + provider_id=provider_id, + endpoint_id=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + properties=subscription_builder.properties, + credential_expires_at=subscription_builder.credential_expires_at or -1, + expires_at=subscription_builder.expires_at, + credentials=subscription_builder.credentials, + credential_type=credential_type, + ) + else: + # automatically create + subscription: Subscription = TriggerManager.subscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + endpoint=parse_endpoint_id(subscription_builder.endpoint_id), + parameters=subscription_builder.parameters, + credentials=subscription_builder.credentials, + credential_type=credential_type, + ) - TriggerProviderService.add_trigger_subscription( - subscription_id=subscription_builder.id, - tenant_id=tenant_id, - user_id=user_id, - name=subscription_builder.name, - provider_id=provider_id, - endpoint_id=subscription_builder.endpoint_id, - parameters=subscription_builder.parameters, - properties=subscription.properties, - credentials=subscription_builder.credentials, - credential_type=credential_type, - credential_expires_at=subscription_builder.credential_expires_at or -1, - expires_at=subscription_builder.expires_at, - ) + TriggerProviderService.add_trigger_subscription( + subscription_id=subscription_builder.id, + tenant_id=tenant_id, + user_id=user_id, + name=subscription_builder.name, + provider_id=provider_id, + endpoint_id=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + properties=subscription.properties, + credentials=subscription_builder.credentials, + credential_type=credential_type, + credential_expires_at=subscription_builder.credential_expires_at or -1, + expires_at=subscription_builder.expires_at, + ) - cls.delete_trigger_subscription_builder(subscription_builder_id) + # Delete the builder after successful subscription creation + cache_key = cls.encode_cache_key(subscription_builder_id) + redis_client.delete(cache_key) @classmethod def create_trigger_subscription_builder( @@ -191,17 +219,154 @@ class TriggerSubscriptionBuilderService: if not provider_controller: raise ValueError(f"Provider {provider_id} not found") - cache_key = cls.encode_cache_key(subscription_id) - subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id) - if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: - raise ValueError(f"Subscription {subscription_id} expired or not found") + # Acquire lock to prevent concurrent updates + with cls.acquire_builder_lock(subscription_id): + cache_key = cls.encode_cache_key(subscription_id) + subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: + raise ValueError(f"Subscription {subscription_id} expired or not found") - subscription_builder_updater.update(subscription_builder_cache) + subscription_builder_updater.update(subscription_builder_cache) - redis_client.setex( - cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json() - ) - return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache) + redis_client.setex( + cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json() + ) + return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache) + + @classmethod + def update_and_verify_builder( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + subscription_builder_id: str, + subscription_builder_updater: SubscriptionBuilderUpdater, + ) -> Mapping[str, Any]: + """ + Atomically update and verify a subscription builder. + This ensures the verification is done on the exact data that was just updated. + """ + subscription_id = subscription_builder_id + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + if not provider_controller: + raise ValueError(f"Provider {provider_id} not found") + + # Acquire lock for the entire update + verify operation + with cls.acquire_builder_lock(subscription_id): + cache_key = cls.encode_cache_key(subscription_id) + subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: + raise ValueError(f"Subscription {subscription_id} expired or not found") + + # Update + subscription_builder_updater.update(subscription_builder_cache) + redis_client.setex( + cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json() + ) + + # Verify (using the just-updated data) + if subscription_builder_cache.credential_type == CredentialType.OAUTH2: + return {"verified": bool(subscription_builder_cache.credentials)} + + if subscription_builder_cache.credential_type == CredentialType.API_KEY: + credentials_to_validate = subscription_builder_cache.credentials + try: + provider_controller.validate_credentials(user_id, credentials_to_validate) + except ToolProviderCredentialValidationError as e: + raise ValueError(f"Invalid credentials: {e}") + return {"verified": True} + + return {"verified": True} + + @classmethod + def update_and_build_builder( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + subscription_builder_id: str, + subscription_builder_updater: SubscriptionBuilderUpdater, + ) -> None: + """ + Atomically update and build a subscription builder. + This ensures the build uses the exact data that was just updated. + """ + subscription_id = subscription_builder_id + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + if not provider_controller: + raise ValueError(f"Provider {provider_id} not found") + + # Acquire lock for the entire update + build operation + with cls.acquire_builder_lock(subscription_id): + cache_key = cls.encode_cache_key(subscription_id) + subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: + raise ValueError(f"Subscription {subscription_id} expired or not found") + + # Update + subscription_builder_updater.update(subscription_builder_cache) + redis_client.setex( + cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json() + ) + + # Re-fetch to ensure we have the latest data + subscription_builder = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder: + raise ValueError(f"Subscription builder {subscription_builder_id} not found") + + if not subscription_builder.name: + raise ValueError("Subscription builder name is required") + + # Build + credential_type = CredentialType.of( + subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value + ) + if credential_type == CredentialType.UNAUTHORIZED: + # manually create + TriggerProviderService.add_trigger_subscription( + subscription_id=subscription_builder.id, + tenant_id=tenant_id, + user_id=user_id, + name=subscription_builder.name, + provider_id=provider_id, + endpoint_id=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + properties=subscription_builder.properties, + credential_expires_at=subscription_builder.credential_expires_at or -1, + expires_at=subscription_builder.expires_at, + credentials=subscription_builder.credentials, + credential_type=credential_type, + ) + else: + # automatically create + subscription: Subscription = TriggerManager.subscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + endpoint=parse_endpoint_id(subscription_builder.endpoint_id), + parameters=subscription_builder.parameters, + credentials=subscription_builder.credentials, + credential_type=credential_type, + ) + + TriggerProviderService.add_trigger_subscription( + subscription_id=subscription_builder.id, + tenant_id=tenant_id, + user_id=user_id, + name=subscription_builder.name, + provider_id=provider_id, + endpoint_id=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + properties=subscription.properties, + credentials=subscription_builder.credentials, + credential_type=credential_type, + credential_expires_at=subscription_builder.credential_expires_at or -1, + expires_at=subscription_builder.expires_at, + ) + + # Delete the builder after successful subscription creation + cache_key = cls.encode_cache_key(subscription_builder_id) + redis_client.delete(cache_key) @classmethod def builder_to_api_entity( @@ -222,14 +387,6 @@ class TriggerSubscriptionBuilderService: ), ) - @classmethod - def delete_trigger_subscription_builder(cls, subscription_id: str) -> None: - """ - Delete a trigger subscription validation. - """ - cache_key = cls.encode_cache_key(subscription_id) - redis_client.delete(cache_key) - @classmethod def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None: """