feat(trigger): enhance trigger subscription handling with credential support

- Added `credentials` and `credential_type` parameters to various methods in `PluginTriggerManager`, `PluginTriggerProviderController`, and `TriggerManager` to support improved credential management for trigger subscriptions.
- Updated the `Subscription` model to include `parameters` for better subscription data handling.
- Refactored related services to accommodate the new credential handling, ensuring consistency across the trigger workflow.
This commit is contained in:
Harry 2025-10-11 21:12:09 +08:00
parent 4f65cc312d
commit 42f75b6602
8 changed files with 155 additions and 46 deletions

View File

@ -158,6 +158,8 @@ class PluginTriggerManager(BasePluginClient):
provider: str,
subscription: Mapping[str, Any],
request: Request,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerDispatchResponse:
"""
Dispatch an event to triggers.
@ -173,6 +175,8 @@ class PluginTriggerManager(BasePluginClient):
"data": {
"provider": provider_id.provider_name,
"subscription": subscription,
"credentials": credentials,
"credential_type": credential_type,
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
},
},
@ -197,6 +201,7 @@ class PluginTriggerManager(BasePluginClient):
user_id: str,
provider: str,
credentials: Mapping[str, str],
credential_type: CredentialType,
endpoint: str,
parameters: Mapping[str, Any],
) -> TriggerSubscriptionResponse:
@ -213,6 +218,7 @@ class PluginTriggerManager(BasePluginClient):
"data": {
"provider": provider_id.provider_name,
"credentials": credentials,
"credential_type": credential_type,
"endpoint": endpoint,
"parameters": parameters,
},
@ -235,6 +241,7 @@ class PluginTriggerManager(BasePluginClient):
provider: str,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerSubscriptionResponse:
"""
Unsubscribe from a trigger.
@ -250,6 +257,7 @@ class PluginTriggerManager(BasePluginClient):
"provider": provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
"credential_type": credential_type,
},
},
headers={
@ -270,6 +278,7 @@ class PluginTriggerManager(BasePluginClient):
provider: str,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerSubscriptionResponse:
"""
Refresh a trigger subscription.
@ -285,6 +294,7 @@ class PluginTriggerManager(BasePluginClient):
"provider": provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
"credential_type": credential_type,
},
},
headers={

View File

@ -173,6 +173,7 @@ class Subscription(BaseModel):
)
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
parameters: Mapping[str, Any] = Field(default={}, description="The parameters of the subscription constructor")
properties: Mapping[str, Any] = Field(
..., description="Subscription data containing all properties and provider-specific information"
)

View File

@ -13,6 +13,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import (
TriggerDispatchResponse,
TriggerInvokeEventResponse,
TriggerSubscriptionResponse,
)
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity
@ -249,13 +250,22 @@ class PluginTriggerProviderController:
else []
)
def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
def dispatch(
self,
user_id: str,
request: Request,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerDispatchResponse:
"""
Dispatch a trigger through plugin runtime
:param user_id: User ID
:param request: Flask request object
:param subscription: Subscription
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Dispatch response with triggers and raw HTTP response
"""
manager = PluginTriggerManager()
@ -267,6 +277,8 @@ class PluginTriggerProviderController:
provider=str(provider_id),
subscription=subscription.model_dump(),
request=request,
credentials=credentials,
credential_type=credential_type,
)
return response
@ -306,7 +318,12 @@ class PluginTriggerProviderController:
)
def subscribe_trigger(
self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str]
self,
user_id: str,
endpoint: str,
parameters: Mapping[str, Any],
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Subscription:
"""
Subscribe to a trigger through plugin runtime
@ -315,24 +332,26 @@ class PluginTriggerProviderController:
:param endpoint: Subscription endpoint
:param subscription_params: Subscription parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Subscription result
"""
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
provider_id: TriggerProviderID = self.get_provider_id()
response = manager.subscribe(
response: TriggerSubscriptionResponse = manager.subscribe(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
credentials=credentials,
endpoint=endpoint,
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
return Subscription.model_validate(response.subscription)
def unsubscribe_trigger(
self, user_id: str, subscription: Subscription, credentials: Mapping[str, str]
self, user_id: str, subscription: Subscription, credentials: Mapping[str, str], credential_type: CredentialType
) -> Unsubscription:
"""
Unsubscribe from a trigger through plugin runtime
@ -340,22 +359,26 @@ class PluginTriggerProviderController:
:param user_id: User ID
:param subscription: Subscription metadata
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Unsubscription result
"""
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
provider_id: TriggerProviderID = self.get_provider_id()
response = manager.unsubscribe(
response: TriggerSubscriptionResponse = manager.unsubscribe(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
subscription=subscription,
credentials=credentials,
credential_type=credential_type,
)
return Unsubscription.model_validate(response.subscription)
def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription:
def refresh_trigger(
self, subscription: Subscription, credentials: Mapping[str, str], credential_type: CredentialType
) -> Subscription:
"""
Refresh a trigger subscription through plugin runtime
@ -364,14 +387,15 @@ class PluginTriggerProviderController:
:return: Refreshed subscription result
"""
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
provider_id: TriggerProviderID = self.get_provider_id()
response = manager.refresh(
response: TriggerSubscriptionResponse = manager.refresh(
tenant_id=self.tenant_id,
user_id="system", # System refresh
provider=str(provider_id),
subscription=subscription,
credentials=credentials,
credential_type=credential_type,
)
return Subscription.model_validate(response.subscription)

View File

@ -180,6 +180,7 @@ class TriggerManager:
endpoint: str,
parameters: Mapping[str, Any],
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Subscription:
"""
Subscribe to a trigger (e.g., register webhook)
@ -190,11 +191,18 @@ class TriggerManager:
:param endpoint: Subscription endpoint
:param parameters: Subscription parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Subscription result
"""
provider = cls.get_trigger_provider(tenant_id, provider_id)
provider: PluginTriggerProviderController = cls.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
return provider.subscribe_trigger(
user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
user_id=user_id,
endpoint=endpoint,
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
@classmethod
@ -205,6 +213,7 @@ class TriggerManager:
provider_id: TriggerProviderID,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Unsubscription:
"""
Unsubscribe from a trigger
@ -214,10 +223,18 @@ class TriggerManager:
:param provider_id: Provider ID
:param subscription: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Unsubscription result
"""
provider = cls.get_trigger_provider(tenant_id, provider_id)
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
provider: PluginTriggerProviderController = cls.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
return provider.unsubscribe_trigger(
user_id=user_id,
subscription=subscription,
credentials=credentials,
credential_type=credential_type,
)
@classmethod
def refresh_trigger(
@ -226,6 +243,7 @@ class TriggerManager:
provider_id: TriggerProviderID,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Subscription:
"""
Refresh a trigger subscription
@ -234,9 +252,14 @@ class TriggerManager:
:param provider_id: Provider ID
:param subscription: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Refreshed subscription result
"""
return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials)
# TODO you should update the subscription using the return value of the refresh_trigger
return cls.get_trigger_provider(tenant_id=tenant_id, provider_id=provider_id).refresh_trigger(
subscription=subscription, credentials=credentials, credential_type=credential_type
)
# Export

View File

@ -71,6 +71,7 @@ class TriggerSubscription(Base):
return Subscription(
expires_at=self.expires_at,
endpoint=parse_endpoint_id(self.endpoint_id),
parameters=self.parameters,
properties=self.properties,
)

View File

@ -18,6 +18,7 @@ from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
)
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_properties,
@ -166,7 +167,7 @@ class TriggerProviderService:
)
# Create provider record
db_provider = TriggerSubscription(
subscription = TriggerSubscription(
id=subscription_id or str(uuid.uuid4()),
tenant_id=tenant_id,
user_id=user_id,
@ -181,10 +182,10 @@ class TriggerProviderService:
expires_at=expires_at,
)
session.add(db_provider)
session.add(subscription)
session.commit()
return {"result": "success", "id": str(db_provider.id)}
return {"result": "success", "id": str(subscription.id)}
except Exception as e:
logger.exception("Failed to add trigger provider")
@ -228,16 +229,42 @@ class TriggerProviderService:
:param subscription_id: Subscription instance ID
:return: Success response
"""
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not db_provider:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
if is_auto_created:
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=subscription.user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=encrypter.decrypt(subscription.credentials),
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger", exc_info=e)
# Clear cache
session.delete(db_provider)
session.delete(subscription)
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=db_provider.provider_id,
subscription_id=db_provider.id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
@classmethod
@ -254,16 +281,18 @@ class TriggerProviderService:
:return: New token info
"""
with Session(db.engine) as session:
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not db_provider:
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
if db_provider.credential_type != CredentialType.OAUTH2.value:
if subscription.credential_type != CredentialType.OAUTH2.value:
raise ValueError("Only OAuth credentials can be refreshed")
provider_id = TriggerProviderID(db_provider.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
# Create encrypter
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
@ -272,11 +301,11 @@ class TriggerProviderService:
)
# Decrypt current credentials
current_credentials = encrypter.decrypt(db_provider.credentials)
current_credentials = encrypter.decrypt(subscription.credentials)
# Get OAuth client configuration
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback"
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{subscription.provider_id}/trigger/callback"
)
system_credentials = cls.get_oauth_client(tenant_id, provider_id)
@ -284,7 +313,7 @@ class TriggerProviderService:
oauth_handler = OAuthHandler()
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=db_provider.user_id,
user_id=subscription.user_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
redirect_uri=redirect_uri,
@ -293,8 +322,8 @@ class TriggerProviderService:
)
# Update credentials
db_provider.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials))
db_provider.expires_at = refreshed_credentials.expires_at
subscription.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials))
subscription.expires_at = refreshed_credentials.expires_at
session.commit()
# Clear cache
@ -315,7 +344,9 @@ class TriggerProviderService:
:param provider_id: Provider identifier
:return: OAuth client configuration or None
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine, expire_on_commit=False) as session:
tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient)
@ -378,7 +409,9 @@ class TriggerProviderService:
return {"result": "success"}
# Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine) as session:
# Find existing custom client params
@ -450,7 +483,9 @@ class TriggerProviderService:
return {}
# Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
# Create encrypter to decrypt and mask values
encrypter, _ = create_provider_encrypter(
@ -511,8 +546,8 @@ class TriggerProviderService:
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
if not subscription:
return None
provider_controller = TriggerManager.get_trigger_provider(
subscription.tenant_id, TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=subscription.tenant_id,

View File

@ -13,6 +13,7 @@ from core.plugin.utils.http_parser import deserialize_request, serialize_request
from core.trigger.entities.entities import EventEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_schedule.exc import TenantOwnerNotFoundError
from extensions.ext_database import db
@ -150,7 +151,7 @@ class TriggerService:
trigger_type=WorkflowRunTriggeredFrom.PLUGIN,
plugin_id=subscription.provider_id,
endpoint_id=subscription.endpoint_id,
inputs=invoke_response.variables.variables,
inputs=invoke_response.variables,
)
# Trigger async workflow
@ -191,8 +192,17 @@ class TriggerService:
if not controller:
return None
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=subscription.tenant_id,
controller=controller,
subscription=subscription,
)
dispatch_response: TriggerDispatchResponse = controller.dispatch(
user_id=subscription.user_id, request=request, subscription=subscription.to_entity()
user_id=subscription.user_id,
request=request,
subscription=subscription.to_entity(),
credentials=encrypter.decrypt(subscription.credentials),
credential_type=CredentialType.of(subscription.credential_type),
)
if dispatch_response.events:

View File

@ -8,10 +8,12 @@ from typing import Any
from flask import Request, Response
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse
from core.tools.errors import ToolProviderCredentialValidationError
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
from core.trigger.entities.entities import (
RequestLog,
Subscription,
SubscriptionBuilder,
SubscriptionBuilderUpdater,
)
@ -111,13 +113,14 @@ class TriggerSubscriptionBuilderService:
)
else:
# automatically create
subscription = TriggerManager.subscribe_trigger(
subscription: Subscription = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
parameters=subscription_builder.parameters,
credentials=subscription_builder.credentials,
credential_type=credential_type,
)
TriggerProviderService.add_trigger_subscription(
@ -286,18 +289,20 @@ class TriggerSubscriptionBuilderService:
:return: The Flask response object
"""
# check if validation endpoint exists
subscription_builder = cls.get_subscription_builder(endpoint_id)
subscription_builder: SubscriptionBuilder | None = cls.get_subscription_builder(endpoint_id)
if not subscription_builder:
return None
# response to validation endpoint
controller = TriggerManager.get_trigger_provider(
subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
)
response = controller.dispatch(
response: TriggerDispatchResponse = controller.dispatch(
user_id=subscription_builder.user_id,
request=request,
subscription=subscription_builder.to_subscription(),
credentials={},
credential_type=CredentialType.UNAUTHORIZED,
)
# append the request log
cls.append_log(endpoint_id, request, response.response)