From afd8989150c366e984407b4177338615d04253be Mon Sep 17 00:00:00 2001 From: Harry Date: Tue, 2 Sep 2025 12:06:27 +0800 Subject: [PATCH] feat(trigger): introduce subscription builder and enhance trigger management - Refactor trigger provider classes to improve naming consistency, including renaming classes for subscription management - Implement new TriggerSubscriptionBuilderService for creating and verifying subscription builders - Update API endpoints to support subscription builder creation and verification - Enhance data models to include new attributes for subscription builders - Remove the deprecated TriggerSubscriptionValidationService to streamline the codebase Co-authored-by: Claude --- .../console/workspace/trigger_providers.py | 106 ++++++-- api/controllers/trigger/trigger.py | 4 +- api/core/plugin/entities/plugin_daemon.py | 5 + api/core/plugin/entities/request.py | 3 + api/core/plugin/impl/trigger.py | 17 +- api/core/trigger/entities/api_entities.py | 24 -- api/core/trigger/entities/entities.py | 53 +++- api/core/trigger/provider.py | 34 ++- api/core/trigger/trigger_manager.py | 27 +- api/core/trigger/utils/endpoint.py | 5 + ...ptimize_trigger_provider_endpoint_index.py | 42 +++ api/models/trigger.py | 18 +- .../trigger/trigger_provider_service.py | 179 ++----------- .../trigger_subscription_builder_service.py | 240 ++++++++++++++++++ ...trigger_subscription_validation_service.py | 48 ---- api/services/trigger_service.py | 189 +++----------- .../core/plugin/utils/test_http_parser.py | 26 +- 17 files changed, 544 insertions(+), 476 deletions(-) create mode 100644 api/core/trigger/utils/endpoint.py create mode 100644 api/migrations/versions/2025_09_01_1242-9d83760807c5_optimize_trigger_provider_endpoint_index.py create mode 100644 api/services/trigger/trigger_subscription_builder_service.py delete mode 100644 api/services/trigger/trigger_subscription_validation_service.py diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index aced8b75f3..781e0bc0f6 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -15,6 +15,7 @@ from libs.login import current_user, login_required from models.account import Account from services.plugin.oauth_service import OAuthProxyService from services.trigger.trigger_provider_service import TriggerProviderService +from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ class TriggerProviderListApi(Resource): return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id)) -class TriggerSubscriptionListApi(Resource): +class TriggerProviderSubscriptionListApi(Resource): @setup_required @login_required @account_initialization_required @@ -54,7 +55,7 @@ class TriggerSubscriptionListApi(Resource): raise -class TriggerSubscriptionsAddApi(Resource): +class TriggerSubscriptionBuilderCreateApi(Resource): @setup_required @login_required @account_initialization_required @@ -67,31 +68,23 @@ class TriggerSubscriptionsAddApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument("credential_type", type=str, required=True, nullable=False, location="json") - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("expires_at", type=int, required=False, nullable=True, location="json", default=-1) + parser.add_argument("credentials", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() try: - # Parse credential type - try: - credential_type = CredentialType(args["credential_type"]) - except ValueError: - raise BadRequest(f"Invalid credential_type. Must be one of: {[t.value for t in CredentialType]}") - - result = TriggerProviderService.add_trigger_provider( + credentials = args.get("credentials", {}) + credential_type = CredentialType.API_KEY if credentials else CredentialType.UNAUTHORIZED + subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( tenant_id=user.current_tenant_id, user_id=user.id, provider_id=TriggerProviderID(provider), + credentials=credentials, credential_type=credential_type, - credentials=args["credentials"], - name=args.get("name"), - expires_at=args.get("expires_at", -1), + credential_expires_at=-1, + expires_at=-1, ) - - return result - + return jsonable_encoder({"subscription_builder": subscription_builder}) except ValueError as e: raise BadRequest(str(e)) except Exception as e: @@ -99,6 +92,58 @@ class TriggerSubscriptionsAddApi(Resource): raise +class TriggerSubscriptionBuilderVerifyApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider, subscription_builder_id): + """Verify a subscription instance for a trigger provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + TriggerSubscriptionBuilderService.verify_trigger_subscription_builder( + tenant_id=user.current_tenant_id, + user_id=user.id, + provider_id=TriggerProviderID(provider), + subscription_builder_id=subscription_builder_id, + ) + return 200 + except Exception as e: + logger.exception("Error verifying provider credential", exc_info=e) + raise + + +class TriggerSubscriptionBuilderBuildApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider, subscription_builder_id): + """Build a subscription instance for a trigger provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + TriggerSubscriptionBuilderService.build_trigger_subscription_builder( + tenant_id=user.current_tenant_id, + user_id=user.id, + provider_id=TriggerProviderID(provider), + subscription_builder_id=subscription_builder_id, + ) + return 200 + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error building provider credential", exc_info=e) + raise + + class TriggerSubscriptionsDeleteApi(Resource): @setup_required @login_required @@ -239,17 +284,17 @@ class TriggerOAuthCallbackApi(Resource): raise Exception("Failed to get OAuth credentials") # Save OAuth credentials to database - TriggerProviderService.add_trigger_provider( + subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( tenant_id=tenant_id, user_id=user_id, provider_id=provider_id, + credentials=credentials, credential_type=CredentialType.OAUTH2, - credentials=dict(credentials), + credential_expires_at=expires_at, expires_at=expires_at, ) - # Redirect to OAuth callback page - return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback?subscription_id={subscription_builder.id}") class TriggerOAuthRefreshTokenApi(Resource): @@ -380,11 +425,18 @@ class TriggerOAuthClientManageApi(Resource): # Trigger provider endpoints api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers") +api.add_resource(TriggerProviderSubscriptionListApi, "/workspaces/current/trigger-provider//list") api.add_resource( - TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions//list" + TriggerSubscriptionBuilderCreateApi, + "/workspaces/current/trigger-provider/subscriptions//create-builder", ) api.add_resource( - TriggerSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions//add" + TriggerSubscriptionBuilderVerifyApi, + "/workspaces/current/trigger-provider/subscriptions//verify/", +) +api.add_resource( + TriggerSubscriptionBuilderBuildApi, + "/workspaces/current/trigger-provider/subscriptions//build/", ) api.add_resource( TriggerSubscriptionsDeleteApi, @@ -393,13 +445,11 @@ api.add_resource( # OAuth api.add_resource( - TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider//oauth/authorize" + TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/subscriptions//oauth/authorize" ) api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin//trigger/callback") api.add_resource( TriggerOAuthRefreshTokenApi, "/workspaces/current/trigger-provider/subscriptions//oauth/refresh", ) -api.add_resource( - TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider//oauth/client" -) +api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider//oauth/client") diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py index 12e4ef9523..9597e72b14 100644 --- a/api/controllers/trigger/trigger.py +++ b/api/controllers/trigger/trigger.py @@ -5,7 +5,7 @@ from flask import jsonify, request from werkzeug.exceptions import NotFound from controllers.trigger import bp -from services.trigger.trigger_subscription_validation_service import TriggerSubscriptionValidationService +from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService from services.trigger_service import TriggerService logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ def trigger_endpoint(endpoint_id: str): raise NotFound("Invalid endpoint ID") handling_chain = [ TriggerService.process_endpoint, - TriggerSubscriptionValidationService.process_validating_endpoint, + TriggerSubscriptionBuilderService.process_builder_validation_endpoint, ] try: for handler in handling_chain: diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 87f441edd9..d5b5aa4045 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -210,12 +210,15 @@ class PluginTriggerProviderEntity(BaseModel): class CredentialType(enum.StrEnum): API_KEY = "api-key" OAUTH2 = "oauth2" + UNAUTHORIZED = "unauthorized" def get_name(self): if self == CredentialType.API_KEY: return "API KEY" elif self == CredentialType.OAUTH2: return "AUTH" + elif self == CredentialType.UNAUTHORIZED: + return "UNAUTHORIZED" else: return self.value.replace("-", " ").upper() @@ -236,5 +239,7 @@ class CredentialType(enum.StrEnum): return cls.API_KEY elif type_name == "oauth2": return cls.OAUTH2 + elif type_name == "unauthorized": + return cls.UNAUTHORIZED else: raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 3a573ef472..da0f2b8419 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -248,14 +248,17 @@ class PluginTriggerDispatchResponse(BaseModel): triggers: list[str] raw_http_response: str + class TriggerSubscriptionResponse(BaseModel): subscription: dict[str, Any] + class TriggerValidateProviderCredentialsResponse(BaseModel): valid: bool message: str error: str + class TriggerDispatchResponse: triggers: list[str] response: Response diff --git a/api/core/plugin/impl/trigger.py b/api/core/plugin/impl/trigger.py index d56db80588..ee643c9d76 100644 --- a/api/core/plugin/impl/trigger.py +++ b/api/core/plugin/impl/trigger.py @@ -1,4 +1,5 @@ import binascii +from collections.abc import Mapping from typing import Any from flask import Request @@ -84,10 +85,10 @@ class PluginTriggerManager(BasePluginClient): user_id: str, provider: str, trigger: str, - credentials: dict[str, Any], + credentials: Mapping[str, str], credential_type: CredentialType, request: Request, - parameters: dict[str, Any], + parameters: Mapping[str, Any], ) -> TriggerInvokeResponse: """ Invoke a trigger with the given parameters. @@ -121,7 +122,7 @@ class PluginTriggerManager(BasePluginClient): raise ValueError("No response received from plugin daemon for invoke trigger") def validate_provider_credentials( - self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str] ) -> TriggerValidateProviderCredentialsResponse: """ Validate the credentials of the trigger provider. @@ -155,7 +156,7 @@ class PluginTriggerManager(BasePluginClient): tenant_id: str, user_id: str, provider: str, - subscription: dict[str, Any], + subscription: Mapping[str, Any], request: Request, ) -> TriggerDispatchResponse: """ @@ -194,9 +195,9 @@ class PluginTriggerManager(BasePluginClient): tenant_id: str, user_id: str, provider: str, - credentials: dict[str, Any], + credentials: Mapping[str, str], endpoint: str, - parameters: dict[str, Any], + parameters: Mapping[str, Any], ) -> TriggerSubscriptionResponse: """ Subscribe to a trigger. @@ -233,7 +234,7 @@ class PluginTriggerManager(BasePluginClient): user_id: str, provider: str, subscription: Subscription, - credentials: dict[str, Any], + credentials: Mapping[str, str], ) -> TriggerSubscriptionResponse: """ Unsubscribe from a trigger. @@ -269,7 +270,7 @@ class PluginTriggerManager(BasePluginClient): user_id: str, provider: str, subscription: Subscription, - credentials: dict[str, Any], + credentials: Mapping[str, str], ) -> TriggerSubscriptionResponse: """ Refresh a trigger subscription. diff --git a/api/core/trigger/entities/api_entities.py b/api/core/trigger/entities/api_entities.py index 20687cfeb8..5f551ff034 100644 --- a/api/core/trigger/entities/api_entities.py +++ b/api/core/trigger/entities/api_entities.py @@ -7,7 +7,6 @@ from core.entities.provider_entities import ProviderConfig from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import ( OAuthSchema, - Subscription, SubscriptionSchema, TriggerDescription, TriggerEntity, @@ -40,27 +39,4 @@ class TriggerApiEntity(BaseModel): parameters: list[TriggerParameter] = Field(description="The parameters of the trigger") output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger") -class SubscriptionValidation(BaseModel): - id: str - name: str - tenant_id: str - user_id: str - provider_id: str - endpoint: str - parameters: dict - properties: dict - credentials: dict - credential_type: str - credential_expires_at: int - expires_at: int - - def to_subscription(self) -> Subscription: - return Subscription( - expires_at=self.expires_at, - endpoint=self.endpoint, - parameters=self.parameters, - properties=self.properties, - ) - - __all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"] diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py index 99163e8c6f..359deec71c 100644 --- a/api/core/trigger/entities/entities.py +++ b/api/core/trigger/entities/entities.py @@ -115,6 +115,18 @@ class SubscriptionSchema(BaseModel): description="The configuration schema stored in the subscription entity", ) + def get_default_parameters(self) -> Mapping[str, Any]: + """Get the default parameters from the parameters schema""" + if not self.parameters_schema: + return {} + return {param.name: param.default for param in self.parameters_schema if param.default} + + def get_default_properties(self) -> Mapping[str, Any]: + """Get the default properties from the properties schema""" + if not self.properties_schema: + return {} + return {prop.name: prop.default for prop in self.properties_schema if prop.default} + class TriggerProviderEntity(BaseModel): """ @@ -148,13 +160,7 @@ class Subscription(BaseModel): ) endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events") - - parameters: dict[str, Any] | None = Field( - default=None, - description="""The parameters of the subscription, this is the creation parameters. - Only available when creating a new subscription by credentials(auto subscription), not manual subscription""", - ) - properties: dict[str, Any] = Field( + properties: Mapping[str, Any] = Field( ..., description="Subscription data containing all properties and provider-specific information" ) @@ -177,10 +183,43 @@ class Unsubscription(BaseModel): ) +class RequestLog(BaseModel): + id: str + endpoint: str + request: dict + response: dict + created_at: str + + +class SubscriptionBuilder(BaseModel): + id: str + name: str | None = None + tenant_id: str + user_id: str + provider_id: str + endpoint_id: str + parameters: Mapping[str, Any] + properties: Mapping[str, Any] + credentials: Mapping[str, str] + credential_type: str | None = None + credential_expires_at: int | None = None + expires_at: int + + def to_subscription(self) -> Subscription: + return Subscription( + expires_at=self.expires_at, + endpoint=self.endpoint_id, + parameters=self.parameters, + properties=self.properties, + ) + + # Export all entities __all__ = [ "OAuthSchema", + "RequestLog", "Subscription", + "SubscriptionBuilder", "TriggerDescription", "TriggerEntity", "TriggerIdentity", diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py index e03f64da27..6dae530184 100644 --- a/api/core/trigger/provider.py +++ b/api/core/trigger/provider.py @@ -3,7 +3,8 @@ Trigger Provider Controller for managing trigger providers """ import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from flask import Request @@ -20,6 +21,7 @@ from core.trigger.entities.api_entities import TriggerProviderApiEntity from core.trigger.entities.entities import ( ProviderConfig, Subscription, + SubscriptionSchema, TriggerEntity, TriggerProviderEntity, TriggerProviderIdentity, @@ -91,18 +93,17 @@ class PluginTriggerProviderController: return trigger return None - def get_subscription_schema(self) -> list[ProviderConfig]: + def get_subscription_schema(self) -> SubscriptionSchema: """ Get subscription schema for this provider :return: List of subscription config schemas """ - # Return the parameters schema from the subscription schema - if self.entity.subscription_schema and self.entity.subscription_schema.parameters_schema: - return self.entity.subscription_schema.parameters_schema - return [] + return self.entity.subscription_schema - def validate_credentials(self, credentials: dict) -> TriggerValidateProviderCredentialsResponse: + def validate_credentials( + self, user_id: str, credentials: Mapping[str, str] + ) -> TriggerValidateProviderCredentialsResponse: """ Validate credentials against schema @@ -123,7 +124,7 @@ class PluginTriggerProviderController: provider_id = self.get_provider_id() return manager.validate_provider_credentials( tenant_id=self.tenant_id, - user_id="system", # System validation + user_id=user_id, provider=str(provider_id), credentials=credentials, ) @@ -153,6 +154,7 @@ class PluginTriggerProviderController: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] if credential_type == CredentialType.API_KEY: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + raise ValueError(f"Invalid credential type: {credential_type}") def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]: """ @@ -168,7 +170,7 @@ class PluginTriggerProviderController: """ return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] - def dispatch(self,user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse: + def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse: """ Dispatch a trigger through plugin runtime @@ -193,8 +195,8 @@ class PluginTriggerProviderController: self, user_id: str, trigger_name: str, - parameters: dict, - credentials: dict, + parameters: Mapping[str, Any], + credentials: Mapping[str, str], credential_type: CredentialType, request: Request, ) -> TriggerInvokeResponse: @@ -223,7 +225,9 @@ class PluginTriggerProviderController: parameters=parameters, ) - def subscribe_trigger(self, user_id: str, endpoint: str, parameters: dict, credentials: dict) -> Subscription: + def subscribe_trigger( + self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str] + ) -> Subscription: """ Subscribe to a trigger through plugin runtime @@ -247,7 +251,9 @@ class PluginTriggerProviderController: return Subscription.model_validate(response.subscription) - def unsubscribe_trigger(self, user_id: str, subscription: Subscription, credentials: dict) -> Unsubscription: + def unsubscribe_trigger( + self, user_id: str, subscription: Subscription, credentials: Mapping[str, str] + ) -> Unsubscription: """ Unsubscribe from a trigger through plugin runtime @@ -269,7 +275,7 @@ class PluginTriggerProviderController: return Unsubscription.model_validate(response.subscription) - def refresh_trigger(self, subscription: Subscription, credentials: dict) -> Subscription: + def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription: """ Refresh a trigger subscription through plugin runtime diff --git a/api/core/trigger/trigger_manager.py b/api/core/trigger/trigger_manager.py index 0016675f57..b20b1fdd79 100644 --- a/api/core/trigger/trigger_manager.py +++ b/api/core/trigger/trigger_manager.py @@ -3,7 +3,8 @@ Trigger Manager for loading and managing trigger providers and triggers """ import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from flask import Request @@ -12,8 +13,8 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerInvokeResponse from core.plugin.impl.trigger import PluginTriggerManager from core.trigger.entities.entities import ( - ProviderConfig, Subscription, + SubscriptionSchema, TriggerEntity, Unsubscription, ) @@ -116,19 +117,20 @@ class TriggerManager: @classmethod def validate_trigger_credentials( - cls, tenant_id: str, provider_id: TriggerProviderID, credentials: dict + cls, tenant_id: str, provider_id: TriggerProviderID, user_id: str, credentials: Mapping[str, str] ) -> tuple[bool, str]: """ Validate trigger provider credentials :param tenant_id: Tenant ID :param provider_id: Provider ID + :param user_id: User ID :param credentials: Credentials to validate :return: Tuple of (is_valid, error_message) """ try: provider = cls.get_trigger_provider(tenant_id, provider_id) - validation_result = provider.validate_credentials(credentials) + validation_result = provider.validate_credentials(user_id, credentials) return validation_result.valid, validation_result.message if not validation_result.valid else "" except Exception as e: return False, str(e) @@ -140,8 +142,8 @@ class TriggerManager: user_id: str, provider_id: TriggerProviderID, trigger_name: str, - parameters: dict, - credentials: dict, + parameters: Mapping[str, Any], + credentials: Mapping[str, str], credential_type: CredentialType, request: Request, ) -> TriggerInvokeResponse: @@ -171,8 +173,8 @@ class TriggerManager: user_id: str, provider_id: TriggerProviderID, endpoint: str, - parameters: dict, - credentials: dict, + parameters: Mapping[str, Any], + credentials: Mapping[str, str], ) -> Subscription: """ Subscribe to a trigger (e.g., register webhook) @@ -197,7 +199,7 @@ class TriggerManager: user_id: str, provider_id: TriggerProviderID, subscription: Subscription, - credentials: dict, + credentials: Mapping[str, str], ) -> Unsubscription: """ Unsubscribe from a trigger @@ -213,7 +215,7 @@ class TriggerManager: return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials) @classmethod - def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]: + def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> SubscriptionSchema: """ Get provider subscription schema @@ -228,9 +230,8 @@ class TriggerManager: cls, tenant_id: str, provider_id: TriggerProviderID, - trigger_name: str, subscription: Subscription, - credentials: dict, + credentials: Mapping[str, str], ) -> Subscription: """ Refresh a trigger subscription @@ -242,7 +243,7 @@ class TriggerManager: :param credentials: Provider credentials :return: Refreshed subscription result """ - return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(trigger_name, subscription, credentials) + return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials) # Export diff --git a/api/core/trigger/utils/endpoint.py b/api/core/trigger/utils/endpoint.py new file mode 100644 index 0000000000..242075acac --- /dev/null +++ b/api/core/trigger/utils/endpoint.py @@ -0,0 +1,5 @@ +from configs import dify_config + + +def parse_endpoint_id(endpoint_id: str) -> str: + return f"{dify_config.CONSOLE_API_URL}/console/api/trigger/endpoint/{endpoint_id}" diff --git a/api/migrations/versions/2025_09_01_1242-9d83760807c5_optimize_trigger_provider_endpoint_index.py b/api/migrations/versions/2025_09_01_1242-9d83760807c5_optimize_trigger_provider_endpoint_index.py new file mode 100644 index 0000000000..2b6c4113ce --- /dev/null +++ b/api/migrations/versions/2025_09_01_1242-9d83760807c5_optimize_trigger_provider_endpoint_index.py @@ -0,0 +1,42 @@ +"""Optimize trigger provider endpoint index + +Revision ID: 9d83760807c5 +Revises: 9ee7d347f4c1 +Create Date: 2025-09-01 12:42:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9d83760807c5' +down_revision = '9ee7d347f4c1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # Drop the old unique constraint on endpoint + with op.batch_alter_table('trigger_providers', schema=None) as batch_op: + batch_op.drop_constraint('unique_trigger_provider_endpoint', type_='unique') + + # Create a new unique index on endpoint for O(1) lookup + batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint'], unique=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # Drop the new index + with op.batch_alter_table('trigger_providers', schema=None) as batch_op: + batch_op.drop_index('idx_trigger_providers_endpoint') + + # Recreate the old unique constraint + batch_op.create_unique_constraint('unique_trigger_provider_endpoint', ['endpoint']) + + # ### end Alembic commands ### diff --git a/api/models/trigger.py b/api/models/trigger.py index 2d40f6f78b..bd107c541a 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity from core.trigger.entities.entities import Subscription +from core.trigger.utils.endpoint import parse_endpoint_id from models.base import Base from models.types import StringUUID @@ -22,10 +23,13 @@ class TriggerSubscription(Base): __tablename__ = "trigger_subscriptions" __table_args__ = ( - sa.PrimaryKeyConstraint("id", name="trigger_subscription_pkey"), - Index("idx_trigger_subscriptions_tenant_provider", "tenant_id", "provider_id"), - UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_subscription"), - UniqueConstraint("endpoint", name="unique_trigger_subscription_endpoint"), + sa.PrimaryKeyConstraint("id", name="trigger_provider_pkey"), + Index("idx_trigger_providers_tenant_provider", "tenant_id", "provider_id"), + # Primary index for O(1) lookup by endpoint + Index("idx_trigger_providers_endpoint", "endpoint_id", unique=True), + # Composite index for tenant-specific queries (optional, kept for compatibility) + Index("idx_trigger_providers_tenant_endpoint", "tenant_id", "endpoint_id"), + UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) @@ -35,7 +39,7 @@ class TriggerSubscription(Base): provider_id: Mapped[str] = mapped_column( String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)" ) - endpoint: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint") + endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint") parameters: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON") properties: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON") @@ -63,8 +67,7 @@ class TriggerSubscription(Base): def to_entity(self) -> Subscription: return Subscription( expires_at=self.expires_at, - endpoint=self.endpoint, - parameters=self.parameters, + endpoint=parse_endpoint_id(self.endpoint_id), properties=self.properties, ) @@ -77,6 +80,7 @@ class TriggerSubscription(Base): credentials=self.credentials, ) + # system level trigger oauth client params class TriggerOAuthSystemClient(Base): __tablename__ = "trigger_oauth_system_clients" diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 8f8053b6bd..2b1f75c63b 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -1,6 +1,5 @@ import json import logging -import re from collections.abc import Mapping from typing import Any, Optional @@ -16,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from core.trigger.entities.api_entities import ( - SubscriptionValidation, TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, ) @@ -36,6 +34,9 @@ logger = logging.getLogger(__name__) class TriggerProviderService: """Service for managing trigger providers and credentials""" + ########################## + # Trigger provider + ########################## __MAX_TRIGGER_PROVIDER_COUNT__ = 10 @classmethod @@ -73,10 +74,14 @@ class TriggerProviderService: cls, tenant_id: str, user_id: str, + name: str, provider_id: TriggerProviderID, + endpoint_id: str, credential_type: CredentialType, - credentials: dict, - name: Optional[str] = None, + parameters: Mapping[str, Any], + properties: Mapping[str, Any], + credentials: Mapping[str, str], + credential_expires_at: int = -1, expires_at: int = -1, ) -> dict: """ @@ -93,7 +98,7 @@ class TriggerProviderService: """ try: provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) - with Session(db.engine) as session: + with Session(db.engine, autoflush=False) as session: # Use distributed lock to prevent race conditions lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}" with redis_client.lock(lock_key, timeout=20): @@ -110,23 +115,14 @@ class TriggerProviderService: f"reached for {provider_id}" ) - # Generate name if not provided - if not name: - name = cls._generate_provider_name( - session=session, - tenant_id=tenant_id, - provider_id=provider_id, - credential_type=credential_type, - ) - else: - # Check if name already exists - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name) - .first() - ) - if existing: - raise ValueError(f"Credential name '{name}' already exists for this provider") + # Check if name already exists + existing = ( + session.query(TriggerSubscription) + .filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name) + .first() + ) + if existing: + raise ValueError(f"Credential name '{name}' already exists for this provider") encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, @@ -138,10 +134,14 @@ class TriggerProviderService: db_provider = TriggerSubscription( tenant_id=tenant_id, user_id=user_id, - provider_id=provider_id, - credential_type=credential_type.value, - credentials=encrypter.encrypt(credentials), name=name, + endpoint_id=endpoint_id, + provider_id=provider_id, + parameters=parameters, + properties=properties, + credentials=encrypter.encrypt(dict(credentials)), + credential_type=credential_type.value, + credential_expires_at=credential_expires_at, expires_at=expires_at, ) @@ -154,70 +154,6 @@ class TriggerProviderService: logger.exception("Failed to add trigger provider") raise ValueError(str(e)) - @classmethod - def update_trigger_provider( - cls, - tenant_id: str, - subscription_id: str, - credentials: Optional[dict] = None, - name: Optional[str] = None, - ) -> dict: - """ - Update an existing trigger provider's credentials or name. - - :param tenant_id: Tenant ID - :param subscription_id: Subscription instance ID - :param credentials: New credentials (optional) - :param name: New name (optional) - :return: Success response - """ - with Session(db.engine) as session: - db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() - if not db_provider: - raise ValueError(f"Trigger provider subscription {subscription_id} not found") - - try: - provider_controller = TriggerManager.get_trigger_provider( - tenant_id, TriggerProviderID(db_provider.provider_id) - ) - - if credentials: - encrypter, cache = create_trigger_provider_encrypter_for_subscription( - tenant_id=tenant_id, - controller=provider_controller, - subscription=db_provider, - ) - # Handle hidden values - original_credentials = encrypter.decrypt(db_provider.credentials) - new_credentials = { - key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) - for key, value in credentials.items() - } - - db_provider.credentials = encrypter.encrypt(new_credentials) - cache.delete() - - # Update name if provided - if name and name != db_provider.name: - # Check if name already exists - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name) - .filter(TriggerSubscription.id != subscription_id) - .first() - ) - if existing: - raise ValueError(f"Credential name '{name}' already exists") - - db_provider.name = name - - session.commit() - return {"result": "success"} - - except Exception as e: - session.rollback() - raise ValueError(str(e)) - @classmethod def delete_trigger_provider(cls, tenant_id: str, subscription_id: str) -> dict: """ @@ -505,59 +441,6 @@ class TriggerProviderService: ) return custom_client is not None - @classmethod - def _generate_provider_name( - cls, - session: Session, - tenant_id: str, - provider_id: TriggerProviderID, - credential_type: CredentialType, - ) -> str: - """ - Generate a unique name for a provider credential instance. - - :param session: Database session - :param tenant_id: Tenant ID - :param provider: Provider identifier - :param credential_type: Credential type - :return: Generated name - """ - try: - db_providers = ( - session.query(TriggerSubscription) - .filter_by( - tenant_id=tenant_id, - provider_id=provider_id, - credential_type=credential_type.value, - ) - .order_by(desc(TriggerSubscription.created_at)) - .all() - ) - - # Get base name - base_name = credential_type.get_name() - - # Find existing numbered names - pattern = rf"^{re.escape(base_name)}\s+(\d+)$" - numbers = [] - - for db_provider in db_providers: - if db_provider.name: - match = re.match(pattern, db_provider.name.strip()) - if match: - numbers.append(int(match.group(1))) - - # Generate next number - if not numbers: - return f"{base_name} 1" - - max_number = max(numbers) - return f"{base_name} {max_number + 1}" - - except Exception as e: - logger.warning("Error generating provider name") - return f"{credential_type.get_name()} 1" - @classmethod def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None: """ @@ -566,15 +449,3 @@ class TriggerProviderService: with Session(db.engine, autoflush=False) as session: subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first() return subscription - - @classmethod - def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None: - """ - Get a trigger subscription by the endpoint ID. - """ - cache_key = f"trigger:subscription:validation:endpoint:{endpoint_id}" - subscription_cache = redis_client.get(cache_key) - if subscription_cache: - return SubscriptionValidation.model_validate(json.loads(subscription_cache)) - - return None \ No newline at end of file diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py new file mode 100644 index 0000000000..cf7f564d68 --- /dev/null +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -0,0 +1,240 @@ +import json +import logging +import uuid +from collections.abc import Mapping + +from flask import Request, Response + +from core.plugin.entities.plugin import TriggerProviderID +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.entities.entities import ( + RequestLog, + SubscriptionBuilder, +) +from core.trigger.trigger_manager import TriggerManager +from extensions.ext_redis import redis_client +from services.trigger.trigger_provider_service import TriggerProviderService + +logger = logging.getLogger(__name__) + + +class TriggerSubscriptionBuilderService: + """Service for managing trigger providers and credentials""" + + ########################## + # Trigger provider + ########################## + __MAX_TRIGGER_PROVIDER_COUNT__ = 10 + + ########################## + # Validation endpoint + ########################## + __VALIDATION_REQUEST_CACHE_COUNT__ = 10 + __VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 30 * 60 * 1000 + + @classmethod + def encode_cache_key(cls, subscription_id: str) -> str: + return f"trigger:subscription:validation:{subscription_id}" + + @classmethod + def verify_trigger_subscription_builder( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + subscription_builder_id: str, + ) -> None: + """Verify a trigger subscription builder""" + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + if not provider_controller: + raise ValueError(f"Provider {provider_id} not found") + + subscription_builder = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder: + raise ValueError(f"Subscription builder {subscription_builder_id} not found") + + provider_controller.validate_credentials(user_id, subscription_builder.credentials) + + @classmethod + def build_trigger_subscription_builder( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + subscription_builder_id: str, + ) -> None: + """Build a trigger subscription builder""" + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + if not provider_controller: + raise ValueError(f"Provider {provider_id} not found") + + subscription_builder = cls.get_subscription_builder(subscription_builder_id) + if not subscription_builder: + raise ValueError(f"Subscription builder {subscription_builder_id} not found") + + if subscription_builder.name is None: + raise ValueError("Subscription builder name is required") + + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value) + if credential_type == CredentialType.UNAUTHORIZED: + # manually create + TriggerProviderService.add_trigger_provider( + tenant_id=tenant_id, + user_id=user_id, + name=subscription_builder.name, + provider_id=provider_id, + endpoint_id=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + properties=subscription_builder.properties, + credential_expires_at=subscription_builder.credential_expires_at or -1, + expires_at=subscription_builder.expires_at, + credentials=subscription_builder.credentials, + credential_type=credential_type, + ) + else: + # automatically create + subscription = TriggerManager.subscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + endpoint=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + credentials=subscription_builder.credentials, + ) + + TriggerProviderService.add_trigger_provider( + tenant_id=tenant_id, + user_id=user_id, + name=subscription_builder.name, + provider_id=provider_id, + endpoint_id=subscription_builder.endpoint_id, + parameters=subscription_builder.parameters, + properties=subscription.properties, + credentials=subscription_builder.credentials, + credential_type=credential_type, + credential_expires_at=subscription_builder.credential_expires_at or -1, + expires_at=subscription_builder.expires_at, + ) + + cls.delete_trigger_subscription_builder(subscription_builder_id) + + @classmethod + def create_trigger_subscription_builder( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + credentials: Mapping[str, str], + credential_type: CredentialType, + credential_expires_at: int, + expires_at: int, + ) -> SubscriptionBuilder: + """ + Add a new trigger subscription validation. + """ + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + if not provider_controller: + raise ValueError(f"Provider {provider_id} not found") + + subscription_schema = provider_controller.get_subscription_schema() + subscription_id = str(uuid.uuid4()) + subscription_builder = SubscriptionBuilder( + id=subscription_id, + name="", + endpoint_id=subscription_id, + tenant_id=tenant_id, + user_id=user_id, + provider_id=str(provider_id), + parameters=subscription_schema.get_default_parameters(), + properties=subscription_schema.get_default_properties(), + credentials=credentials, + credential_type=credential_type, + credential_expires_at=credential_expires_at, + expires_at=expires_at, + ) + cache_key = cls.encode_cache_key(subscription_id) + redis_client.setex( + cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json() + ) + return subscription_builder + + @classmethod + def update_trigger_subscription_builder( + cls, + tenant_id: str, + subscription_builder: SubscriptionBuilder, + ) -> SubscriptionBuilder: + """ + Update a trigger subscription validation. + """ + subscription_id = subscription_builder.id + cache_key = cls.encode_cache_key(subscription_id) + subscription_builder_cache = cls.get_subscription_builder(subscription_id) + if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: + raise ValueError(f"Subscription {subscription_id} not found") + + redis_client.setex( + cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json() + ) + return subscription_builder + + @classmethod + def delete_trigger_subscription_builder(cls, subscription_id: str) -> None: + """ + Delete a trigger subscription validation. + """ + cache_key = cls.encode_cache_key(subscription_id) + redis_client.delete(cache_key) + + @classmethod + def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None: + """ + Get a trigger subscription by the endpoint ID. + """ + cache_key = cls.encode_cache_key(endpoint_id) + subscription_cache = redis_client.get(cache_key) + if subscription_cache: + return SubscriptionBuilder.model_validate(json.loads(subscription_cache)) + + return None + + @classmethod + def append_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None: + """ + Append the validation request log to Redis. + """ + pass + + @classmethod + def list_request_logs(cls, endpoint_id: str) -> list[RequestLog]: + """ + List the request logs for a validation endpoint. + """ + return [] + + @classmethod + def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: + """ + Process a temporary endpoint request. + + :param endpoint_id: The endpoint identifier + :param request: The Flask request object + :return: The Flask response object + """ + # check if validation endpoint exists + subscription_builder = cls.get_subscription_builder(endpoint_id) + if not subscription_builder: + return None + + # response to validation endpoint + controller = TriggerManager.get_trigger_provider( + subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id) + ) + response = controller.dispatch( + user_id=subscription_builder.user_id, + request=request, + subscription=subscription_builder.to_subscription(), + ) + # append the request log + cls.append_request_log(endpoint_id, request, response.response) + return response.response diff --git a/api/services/trigger/trigger_subscription_validation_service.py b/api/services/trigger/trigger_subscription_validation_service.py deleted file mode 100644 index 0a2b8ce1e3..0000000000 --- a/api/services/trigger/trigger_subscription_validation_service.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging - -from flask import Request, Response - -from core.plugin.entities.plugin import TriggerProviderID -from core.trigger.trigger_manager import TriggerManager -from services.trigger.trigger_provider_service import TriggerProviderService - -logger = logging.getLogger(__name__) - - -class TriggerSubscriptionValidationService: - __VALIDATION_REQUEST_CACHE_COUNT__ = 10 - __VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000 - - @classmethod - def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None: - """ - Append the validation request log to Redis. - """ - - - @classmethod - def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: - """ - Process a temporary endpoint request. - - :param endpoint_id: The endpoint identifier - :param request: The Flask request object - :return: The Flask response object - """ - # check if validation endpoint exists - subscription_validation = TriggerProviderService.get_subscription_validation(endpoint_id) - if not subscription_validation: - return None - - # response to validation endpoint - controller = TriggerManager.get_trigger_provider( - subscription_validation.tenant_id, TriggerProviderID(subscription_validation.provider_id) - ) - response = controller.dispatch( - user_id=subscription_validation.user_id, - request=request, - subscription=subscription_validation.to_subscription(), - ) - # append the request log - cls.append_validation_request_log(endpoint_id, request, response.response) - return response.response diff --git a/api/services/trigger_service.py b/api/services/trigger_service.py index 249e9fe33b..4b2bfaab2b 100644 --- a/api/services/trigger_service.py +++ b/api/services/trigger_service.py @@ -1,14 +1,11 @@ -import json import logging -import time -import uuid -from typing import Any, Optional from flask import Request, Response from core.plugin.entities.plugin import TriggerProviderID +from core.trigger.entities.entities import TriggerEntity from core.trigger.trigger_manager import TriggerManager -from extensions.ext_redis import redis_client +from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService logger = logging.getLogger(__name__) @@ -18,34 +15,27 @@ class TriggerService: __TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000 __ENDPOINT_REQUEST_CACHE_COUNT__ = 10 __ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000 - # Lua script for atomic write with time & count based cleanup - __LUA_SCRIPT__ = """ - -- KEYS[1] = zset key - -- ARGV[1] = max_count (maximum number of entries to keep) - -- ARGV[2] = min_ts_ms (minimum timestamp to keep = now_ms - ttl_ms) - -- ARGV[3] = now_ms (current timestamp in milliseconds) - -- ARGV[4] = member (log entry JSON) - local key = KEYS[1] - local maxCount = tonumber(ARGV[1]) - local minTs = tonumber(ARGV[2]) - local nowMs = tonumber(ARGV[3]) - local member = ARGV[4] + @classmethod + def process_triggered_workflows(cls, subscription: TriggerSubscription, trigger: TriggerEntity, request: Request) -> None: + """Process triggered workflows.""" + - -- 1) Add new entry with timestamp as score - redis.call('ZADD', key, nowMs, member) - - -- 2) Remove entries older than minTs (time-based cleanup) - redis.call('ZREMRANGEBYSCORE', key, '-inf', minTs) - - -- 3) Remove oldest entries if count exceeds maxCount (count-based cleanup) - local n = redis.call('ZCARD', key) - if n > maxCount then - redis.call('ZREMRANGEBYRANK', key, 0, n - maxCount - 1) -- 0 is oldest - end - - return n - """ + @classmethod + def select_triggers(cls, controller, dispatch_response, provider_id, subscription) -> list[TriggerEntity]: + triggers = [] + for trigger_name in dispatch_response.triggers: + trigger = controller.get_trigger(trigger_name) + if trigger is None: + logger.error( + "Trigger '%s' not found in provider '%s' for tenant '%s'", + trigger_name, + provider_id, + subscription.tenant_id, + ) + raise ValueError(f"Trigger '{trigger_name}' not found") + triggers.append(trigger) + return triggers @classmethod def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: @@ -53,140 +43,23 @@ class TriggerService: subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id) if not subscription: return None - + provider_id = TriggerProviderID(subscription.provider_id) controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id) if not controller: return None - + dispatch_response = controller.dispatch( user_id=subscription.user_id, request=request, subscription=subscription.to_entity() ) # TODO invoke triggers - # dispatch_response.triggers - + if dispatch_response.triggers: + triggers = cls.select_triggers(controller, dispatch_response, provider_id, subscription) + for trigger in triggers: + cls.process_triggered_workflows( + subscription=subscription, + trigger=trigger, + request=request, + ) return dispatch_response.response - - @classmethod - def log_endpoint_request(cls, endpoint_id: str, request: Request) -> int: - """ - Log the endpoint request to Redis using ZSET for rolling log with time & count based retention. - - Args: - endpoint_id: The endpoint identifier - request: The Flask request object - - Returns: - The current number of logged requests for this endpoint - """ - try: - # Prepare timestamp - now_ms = int(time.time() * 1000) - min_ts = now_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ - - # Extract request data - request_data = { - "id": str(uuid.uuid4()), - "timestamp": now_ms, - "method": request.method, - "path": request.path, - "headers": dict(request.headers), - "query_params": request.args.to_dict(flat=False) if request.args else {}, - "body": None, - "remote_addr": request.remote_addr, - } - - # Try to get request body if it exists - if request.is_json: - try: - request_data["body"] = request.get_json(force=True) - except Exception: - request_data["body"] = request.get_data(as_text=True) - elif request.data: - request_data["body"] = request.get_data(as_text=True) - - # Serialize to JSON - member = json.dumps(request_data, separators=(",", ":")) - - # Execute Lua script atomically - key = f"trigger:endpoint_requests:{endpoint_id}" - count = redis_client.eval( - cls.__LUA_SCRIPT__, - 1, # number of keys - key, # KEYS[1] - str(cls.__ENDPOINT_REQUEST_CACHE_COUNT__), # ARGV[1] - max count - str(min_ts), # ARGV[2] - minimum timestamp - str(now_ms), # ARGV[3] - current timestamp - member, # ARGV[4] - log entry - ) - - logger.debug("Logged request for endpoint %s, current count: %s", endpoint_id, count) - return count - - except Exception as e: - logger.exception("Failed to log endpoint request for %s", endpoint_id, exc_info=e) - # Don't fail the main request processing if logging fails - return 0 - - @classmethod - def get_recent_endpoint_requests( - cls, endpoint_id: str, limit: int = 100, start_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None - ) -> list[dict[str, Any]]: - """ - Retrieve recent logged requests for an endpoint. - - Args: - endpoint_id: The endpoint identifier - limit: Maximum number of entries to return - start_time_ms: Start timestamp in milliseconds (optional) - end_time_ms: End timestamp in milliseconds (optional, defaults to now) - - Returns: - List of request log entries, newest first - """ - try: - key = f"trigger:endpoint_requests:{endpoint_id}" - - # Set time bounds - if end_time_ms is None: - end_time_ms = int(time.time() * 1000) - if start_time_ms is None: - start_time_ms = end_time_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ - - # Get entries in reverse order (newest first) - entries = redis_client.zrevrangebyscore(key, max=end_time_ms, min=start_time_ms, start=0, num=limit) - - # Parse JSON entries - requests = [] - for entry in entries: - try: - requests.append(json.loads(entry)) - except json.JSONDecodeError: - logger.warning("Failed to parse log entry: %s", entry) - - return requests - - except Exception as e: - logger.exception("Failed to retrieve endpoint requests for %s", endpoint_id, exc_info=e) - return [] - - @classmethod - def clear_endpoint_requests(cls, endpoint_id: str) -> bool: - """ - Clear all logged requests for an endpoint. - - Args: - endpoint_id: The endpoint identifier - - Returns: - True if successful, False otherwise - """ - try: - key = f"trigger:endpoint_requests:{endpoint_id}" - redis_client.delete(key) - logger.info("Cleared request logs for endpoint %s", endpoint_id) - return True - except Exception as e: - logger.exception("Failed to clear endpoint requests for %s", endpoint_id, exc_info=e) - return False diff --git a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py index 934331e074..1c2e0c96f8 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py +++ b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py @@ -420,7 +420,7 @@ class TestFileUploads: assert b"POST /api/upload HTTP/1.1\r\n" in raw_data assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data - assert b"Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"" in raw_data + assert b'Content-Disposition: form-data; name="file"; filename="test.txt"' in raw_data assert text_content.encode() in raw_data def test_deserialize_request_with_text_file_upload(self): @@ -465,7 +465,7 @@ class TestFileUploads: boundary = "----BoundaryString123" # Simulate a small PNG file header binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10" - + # Build multipart body body_parts = [] body_parts.append(f"------{boundary}".encode()) @@ -478,7 +478,7 @@ class TestFileUploads: body_parts.append(b"") body_parts.append(b"Test image") body_parts.append(f"------{boundary}--".encode()) - + body = b"\r\n".join(body_parts) environ = { @@ -499,16 +499,16 @@ class TestFileUploads: assert b"POST /api/images HTTP/1.1\r\n" in raw_data assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data - assert b"filename=\"test.png\"" in raw_data + assert b'filename="test.png"' in raw_data assert b"Content-Type: image/png" in raw_data assert binary_content in raw_data def test_deserialize_request_with_binary_file_upload(self): # Test deserializing multipart/form-data request with binary file boundary = "----BoundaryABC123" - # Simulate a small JPEG file header + # Simulate a small JPEG file header binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00" - + body_parts = [] body_parts.append(f"------{boundary}".encode()) body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"') @@ -520,7 +520,7 @@ class TestFileUploads: body_parts.append(b"") body_parts.append(b"Vacation 2024") body_parts.append(f"------{boundary}--".encode()) - + body = b"\r\n".join(body_parts) raw_data = ( @@ -538,7 +538,7 @@ class TestFileUploads: assert request.path == "/api/photos" assert "multipart/form-data" in request.content_type assert request.headers.get("Accept") == "application/json" - + # Verify the binary content is preserved request_body = request.get_data() assert b"photo.jpg" in request_body @@ -553,7 +553,7 @@ class TestFileUploads: boundary = "----MultiFilesBoundary" text_file = b"Text file contents" binary_file = b"\x00\x01\x02\x03\x04\x05" - + body_parts = [] # First file (text) body_parts.append(f"------{boundary}".encode()) @@ -573,7 +573,7 @@ class TestFileUploads: body_parts.append(b"") body_parts.append(b"uploads/2024") body_parts.append(f"------{boundary}--".encode()) - + body = b"\r\n".join(body_parts) environ = { @@ -606,7 +606,7 @@ class TestFileUploads: boundary = "----RoundTripBoundary" file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80" - + body_parts = [] body_parts.append(f"------{boundary}".encode()) body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"') @@ -618,7 +618,7 @@ class TestFileUploads: body_parts.append(b"") body_parts.append(b'{"encoding": "utf-8", "size": 42}') body_parts.append(f"------{boundary}--".encode()) - + body = b"\r\n".join(body_parts) environ = { @@ -647,7 +647,7 @@ class TestFileUploads: assert restored_request.query_string == b"version=2" assert "multipart/form-data" in restored_request.content_type assert boundary in restored_request.content_type - + # Verify file content is preserved restored_body = restored_request.get_data() assert b"emoji.txt" in restored_body