From 42f75b660270a2400e252acbe38daa6b232cb096 Mon Sep 17 00:00:00 2001 From: Harry Date: Sat, 11 Oct 2025 21:12:09 +0800 Subject: [PATCH] feat(trigger): enhance trigger subscription handling with credential support - Added `credentials` and `credential_type` parameters to various methods in `PluginTriggerManager`, `PluginTriggerProviderController`, and `TriggerManager` to support improved credential management for trigger subscriptions. - Updated the `Subscription` model to include `parameters` for better subscription data handling. - Refactored related services to accommodate the new credential handling, ensuring consistency across the trigger workflow. --- api/core/plugin/impl/trigger.py | 10 +++ api/core/trigger/entities/entities.py | 1 + api/core/trigger/provider.py | 46 ++++++++--- api/core/trigger/trigger_manager.py | 33 ++++++-- api/models/trigger.py | 1 + .../trigger/trigger_provider_service.py | 81 +++++++++++++------ api/services/trigger/trigger_service.py | 14 +++- .../trigger_subscription_builder_service.py | 15 ++-- 8 files changed, 155 insertions(+), 46 deletions(-) diff --git a/api/core/plugin/impl/trigger.py b/api/core/plugin/impl/trigger.py index 29ecfd9c76..798714d4f4 100644 --- a/api/core/plugin/impl/trigger.py +++ b/api/core/plugin/impl/trigger.py @@ -158,6 +158,8 @@ class PluginTriggerManager(BasePluginClient): provider: str, subscription: Mapping[str, Any], request: Request, + credentials: Mapping[str, str], + credential_type: CredentialType, ) -> TriggerDispatchResponse: """ Dispatch an event to triggers. @@ -173,6 +175,8 @@ class PluginTriggerManager(BasePluginClient): "data": { "provider": provider_id.provider_name, "subscription": subscription, + "credentials": credentials, + "credential_type": credential_type, "raw_http_request": binascii.hexlify(serialize_request(request)).decode(), }, }, @@ -197,6 +201,7 @@ class PluginTriggerManager(BasePluginClient): user_id: str, provider: str, credentials: Mapping[str, str], + credential_type: CredentialType, endpoint: str, parameters: Mapping[str, Any], ) -> TriggerSubscriptionResponse: @@ -213,6 +218,7 @@ class PluginTriggerManager(BasePluginClient): "data": { "provider": provider_id.provider_name, "credentials": credentials, + "credential_type": credential_type, "endpoint": endpoint, "parameters": parameters, }, @@ -235,6 +241,7 @@ class PluginTriggerManager(BasePluginClient): provider: str, subscription: Subscription, credentials: Mapping[str, str], + credential_type: CredentialType, ) -> TriggerSubscriptionResponse: """ Unsubscribe from a trigger. @@ -250,6 +257,7 @@ class PluginTriggerManager(BasePluginClient): "provider": provider_id.provider_name, "subscription": subscription.model_dump(), "credentials": credentials, + "credential_type": credential_type, }, }, headers={ @@ -270,6 +278,7 @@ class PluginTriggerManager(BasePluginClient): provider: str, subscription: Subscription, credentials: Mapping[str, str], + credential_type: CredentialType, ) -> TriggerSubscriptionResponse: """ Refresh a trigger subscription. @@ -285,6 +294,7 @@ class PluginTriggerManager(BasePluginClient): "provider": provider_id.provider_name, "subscription": subscription.model_dump(), "credentials": credentials, + "credential_type": credential_type, }, }, headers={ diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py index 1240084fda..9fa219dd93 100644 --- a/api/core/trigger/entities/entities.py +++ b/api/core/trigger/entities/entities.py @@ -173,6 +173,7 @@ class Subscription(BaseModel): ) endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events") + parameters: Mapping[str, Any] = Field(default={}, description="The parameters of the subscription constructor") properties: Mapping[str, Any] = Field( ..., description="Subscription data containing all properties and provider-specific information" ) diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py index f45736cb79..3f290aec93 100644 --- a/api/core/trigger/provider.py +++ b/api/core/trigger/provider.py @@ -13,6 +13,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import ( TriggerDispatchResponse, TriggerInvokeEventResponse, + TriggerSubscriptionResponse, ) from core.plugin.impl.trigger import PluginTriggerManager from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity @@ -249,13 +250,22 @@ class PluginTriggerProviderController: else [] ) - def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse: + def dispatch( + self, + user_id: str, + request: Request, + subscription: Subscription, + credentials: Mapping[str, str], + credential_type: CredentialType, + ) -> TriggerDispatchResponse: """ Dispatch a trigger through plugin runtime :param user_id: User ID :param request: Flask request object :param subscription: Subscription + :param credentials: Provider credentials + :param credential_type: Credential type :return: Dispatch response with triggers and raw HTTP response """ manager = PluginTriggerManager() @@ -267,6 +277,8 @@ class PluginTriggerProviderController: provider=str(provider_id), subscription=subscription.model_dump(), request=request, + credentials=credentials, + credential_type=credential_type, ) return response @@ -306,7 +318,12 @@ class PluginTriggerProviderController: ) def subscribe_trigger( - self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str] + self, + user_id: str, + endpoint: str, + parameters: Mapping[str, Any], + credentials: Mapping[str, str], + credential_type: CredentialType, ) -> Subscription: """ Subscribe to a trigger through plugin runtime @@ -315,24 +332,26 @@ class PluginTriggerProviderController: :param endpoint: Subscription endpoint :param subscription_params: Subscription parameters :param credentials: Provider credentials + :param credential_type: Credential type :return: Subscription result """ manager = PluginTriggerManager() - provider_id = self.get_provider_id() + provider_id: TriggerProviderID = self.get_provider_id() - response = manager.subscribe( + response: TriggerSubscriptionResponse = manager.subscribe( tenant_id=self.tenant_id, user_id=user_id, provider=str(provider_id), - credentials=credentials, endpoint=endpoint, parameters=parameters, + credentials=credentials, + credential_type=credential_type, ) return Subscription.model_validate(response.subscription) def unsubscribe_trigger( - self, user_id: str, subscription: Subscription, credentials: Mapping[str, str] + self, user_id: str, subscription: Subscription, credentials: Mapping[str, str], credential_type: CredentialType ) -> Unsubscription: """ Unsubscribe from a trigger through plugin runtime @@ -340,22 +359,26 @@ class PluginTriggerProviderController: :param user_id: User ID :param subscription: Subscription metadata :param credentials: Provider credentials + :param credential_type: Credential type :return: Unsubscription result """ manager = PluginTriggerManager() - provider_id = self.get_provider_id() + provider_id: TriggerProviderID = self.get_provider_id() - response = manager.unsubscribe( + response: TriggerSubscriptionResponse = manager.unsubscribe( tenant_id=self.tenant_id, user_id=user_id, provider=str(provider_id), subscription=subscription, credentials=credentials, + credential_type=credential_type, ) return Unsubscription.model_validate(response.subscription) - def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription: + def refresh_trigger( + self, subscription: Subscription, credentials: Mapping[str, str], credential_type: CredentialType + ) -> Subscription: """ Refresh a trigger subscription through plugin runtime @@ -364,14 +387,15 @@ class PluginTriggerProviderController: :return: Refreshed subscription result """ manager = PluginTriggerManager() - provider_id = self.get_provider_id() + provider_id: TriggerProviderID = self.get_provider_id() - response = manager.refresh( + response: TriggerSubscriptionResponse = manager.refresh( tenant_id=self.tenant_id, user_id="system", # System refresh provider=str(provider_id), subscription=subscription, credentials=credentials, + credential_type=credential_type, ) return Subscription.model_validate(response.subscription) diff --git a/api/core/trigger/trigger_manager.py b/api/core/trigger/trigger_manager.py index 548543ada7..f6237bfb9f 100644 --- a/api/core/trigger/trigger_manager.py +++ b/api/core/trigger/trigger_manager.py @@ -180,6 +180,7 @@ class TriggerManager: endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str], + credential_type: CredentialType, ) -> Subscription: """ Subscribe to a trigger (e.g., register webhook) @@ -190,11 +191,18 @@ class TriggerManager: :param endpoint: Subscription endpoint :param parameters: Subscription parameters :param credentials: Provider credentials + :param credential_type: Credential type :return: Subscription result """ - provider = cls.get_trigger_provider(tenant_id, provider_id) + provider: PluginTriggerProviderController = cls.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) return provider.subscribe_trigger( - user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials + user_id=user_id, + endpoint=endpoint, + parameters=parameters, + credentials=credentials, + credential_type=credential_type, ) @classmethod @@ -205,6 +213,7 @@ class TriggerManager: provider_id: TriggerProviderID, subscription: Subscription, credentials: Mapping[str, str], + credential_type: CredentialType, ) -> Unsubscription: """ Unsubscribe from a trigger @@ -214,10 +223,18 @@ class TriggerManager: :param provider_id: Provider ID :param subscription: Subscription metadata from subscribe operation :param credentials: Provider credentials + :param credential_type: Credential type :return: Unsubscription result """ - provider = cls.get_trigger_provider(tenant_id, provider_id) - return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials) + provider: PluginTriggerProviderController = cls.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) + return provider.unsubscribe_trigger( + user_id=user_id, + subscription=subscription, + credentials=credentials, + credential_type=credential_type, + ) @classmethod def refresh_trigger( @@ -226,6 +243,7 @@ class TriggerManager: provider_id: TriggerProviderID, subscription: Subscription, credentials: Mapping[str, str], + credential_type: CredentialType, ) -> Subscription: """ Refresh a trigger subscription @@ -234,9 +252,14 @@ class TriggerManager: :param provider_id: Provider ID :param subscription: Subscription metadata from subscribe operation :param credentials: Provider credentials + :param credential_type: Credential type :return: Refreshed subscription result """ - return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials) + + # TODO you should update the subscription using the return value of the refresh_trigger + return cls.get_trigger_provider(tenant_id=tenant_id, provider_id=provider_id).refresh_trigger( + subscription=subscription, credentials=credentials, credential_type=credential_type + ) # Export diff --git a/api/models/trigger.py b/api/models/trigger.py index 08dc53d82f..092fd84935 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -71,6 +71,7 @@ class TriggerSubscription(Base): return Subscription( expires_at=self.expires_at, endpoint=parse_endpoint_id(self.endpoint_id), + parameters=self.parameters, properties=self.properties, ) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index fb7fdf81d1..bac0d77a4e 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -18,6 +18,7 @@ from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, ) +from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import ( create_trigger_provider_encrypter_for_properties, @@ -166,7 +167,7 @@ class TriggerProviderService: ) # Create provider record - db_provider = TriggerSubscription( + subscription = TriggerSubscription( id=subscription_id or str(uuid.uuid4()), tenant_id=tenant_id, user_id=user_id, @@ -181,10 +182,10 @@ class TriggerProviderService: expires_at=expires_at, ) - session.add(db_provider) + session.add(subscription) session.commit() - return {"result": "success", "id": str(db_provider.id)} + return {"result": "success", "id": str(subscription.id)} except Exception as e: logger.exception("Failed to add trigger provider") @@ -228,16 +229,42 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: Success response """ - db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() - if not db_provider: + subscription: TriggerSubscription | None = ( + session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + ) + if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") + credential_type: CredentialType = CredentialType.of(subscription.credential_type) + is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY] + if is_auto_created: + provider_id = TriggerProviderID(subscription.provider_id) + provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) + encrypter, _ = create_trigger_provider_encrypter_for_subscription( + tenant_id=tenant_id, + controller=provider_controller, + subscription=subscription, + ) + try: + TriggerManager.unsubscribe_trigger( + tenant_id=tenant_id, + user_id=subscription.user_id, + provider_id=provider_id, + subscription=subscription.to_entity(), + credentials=encrypter.decrypt(subscription.credentials), + credential_type=credential_type, + ) + except Exception as e: + logger.exception("Error unsubscribing trigger", exc_info=e) + # Clear cache - session.delete(db_provider) + session.delete(subscription) delete_cache_for_subscription( tenant_id=tenant_id, - provider_id=db_provider.provider_id, - subscription_id=db_provider.id, + provider_id=subscription.provider_id, + subscription_id=subscription.id, ) @classmethod @@ -254,16 +281,18 @@ class TriggerProviderService: :return: New token info """ with Session(db.engine) as session: - db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() - if not db_provider: + if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") - if db_provider.credential_type != CredentialType.OAUTH2.value: + if subscription.credential_type != CredentialType.OAUTH2.value: raise ValueError("Only OAuth credentials can be refreshed") - provider_id = TriggerProviderID(db_provider.provider_id) - provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + provider_id = TriggerProviderID(subscription.provider_id) + provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) # Create encrypter encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, @@ -272,11 +301,11 @@ class TriggerProviderService: ) # Decrypt current credentials - current_credentials = encrypter.decrypt(db_provider.credentials) + current_credentials = encrypter.decrypt(subscription.credentials) # Get OAuth client configuration redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback" + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{subscription.provider_id}/trigger/callback" ) system_credentials = cls.get_oauth_client(tenant_id, provider_id) @@ -284,7 +313,7 @@ class TriggerProviderService: oauth_handler = OAuthHandler() refreshed_credentials = oauth_handler.refresh_credentials( tenant_id=tenant_id, - user_id=db_provider.user_id, + user_id=subscription.user_id, plugin_id=provider_id.plugin_id, provider=provider_id.provider_name, redirect_uri=redirect_uri, @@ -293,8 +322,8 @@ class TriggerProviderService: ) # Update credentials - db_provider.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials)) - db_provider.expires_at = refreshed_credentials.expires_at + subscription.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials)) + subscription.expires_at = refreshed_credentials.expires_at session.commit() # Clear cache @@ -315,7 +344,9 @@ class TriggerProviderService: :param provider_id: Provider identifier :return: OAuth client configuration or None """ - provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) with Session(db.engine, expire_on_commit=False) as session: tenant_client: TriggerOAuthTenantClient | None = ( session.query(TriggerOAuthTenantClient) @@ -378,7 +409,9 @@ class TriggerProviderService: return {"result": "success"} # Get provider controller to access schema - provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) with Session(db.engine) as session: # Find existing custom client params @@ -450,7 +483,9 @@ class TriggerProviderService: return {} # Get provider controller to access schema - provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=tenant_id, provider_id=provider_id + ) # Create encrypter to decrypt and mask values encrypter, _ = create_provider_encrypter( @@ -511,8 +546,8 @@ class TriggerProviderService: subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first() if not subscription: return None - provider_controller = TriggerManager.get_trigger_provider( - subscription.tenant_id, TriggerProviderID(subscription.provider_id) + provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription( tenant_id=subscription.tenant_id, diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index c384349515..b4ee9e6e1e 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -13,6 +13,7 @@ from core.plugin.utils.http_parser import deserialize_request, serialize_request from core.trigger.entities.entities import EventEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager +from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription from core.workflow.enums import NodeType from core.workflow.nodes.trigger_schedule.exc import TenantOwnerNotFoundError from extensions.ext_database import db @@ -150,7 +151,7 @@ class TriggerService: trigger_type=WorkflowRunTriggeredFrom.PLUGIN, plugin_id=subscription.provider_id, endpoint_id=subscription.endpoint_id, - inputs=invoke_response.variables.variables, + inputs=invoke_response.variables, ) # Trigger async workflow @@ -191,8 +192,17 @@ class TriggerService: if not controller: return None + encrypter, _ = create_trigger_provider_encrypter_for_subscription( + tenant_id=subscription.tenant_id, + controller=controller, + subscription=subscription, + ) dispatch_response: TriggerDispatchResponse = controller.dispatch( - user_id=subscription.user_id, request=request, subscription=subscription.to_entity() + user_id=subscription.user_id, + request=request, + subscription=subscription.to_entity(), + credentials=encrypter.decrypt(subscription.credentials), + credential_type=CredentialType.of(subscription.credential_type), ) if dispatch_response.events: diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 65313f6a2d..8e57996812 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -8,10 +8,12 @@ from typing import Any from flask import Request, Response from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import TriggerDispatchResponse from core.tools.errors import ToolProviderCredentialValidationError from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity from core.trigger.entities.entities import ( RequestLog, + Subscription, SubscriptionBuilder, SubscriptionBuilderUpdater, ) @@ -111,13 +113,14 @@ class TriggerSubscriptionBuilderService: ) else: # automatically create - subscription = TriggerManager.subscribe_trigger( + subscription: Subscription = TriggerManager.subscribe_trigger( tenant_id=tenant_id, user_id=user_id, provider_id=provider_id, endpoint=parse_endpoint_id(subscription_builder.endpoint_id), parameters=subscription_builder.parameters, credentials=subscription_builder.credentials, + credential_type=credential_type, ) TriggerProviderService.add_trigger_subscription( @@ -286,18 +289,20 @@ class TriggerSubscriptionBuilderService: :return: The Flask response object """ # check if validation endpoint exists - subscription_builder = cls.get_subscription_builder(endpoint_id) + subscription_builder: SubscriptionBuilder | None = cls.get_subscription_builder(endpoint_id) if not subscription_builder: return None # response to validation endpoint - controller = TriggerManager.get_trigger_provider( - subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id) + controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( + tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id) ) - response = controller.dispatch( + response: TriggerDispatchResponse = controller.dispatch( user_id=subscription_builder.user_id, request=request, subscription=subscription_builder.to_subscription(), + credentials={}, + credential_type=CredentialType.UNAUTHORIZED, ) # append the request log cls.append_log(endpoint_id, request, response.response)