From 2f08306695157466a27de355c316ec6777f5a087 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 1 Sep 2025 12:08:48 +0800 Subject: [PATCH] feat(trigger): enhance trigger subscription management and processing - Refactor trigger provider classes to improve naming consistency and clarity - Introduce new methods for managing trigger subscriptions, including validation and dispatching - Update API endpoints to reflect changes in subscription handling - Implement logging and request management for endpoint interactions - Enhance data models to support subscription attributes and lifecycle management Co-authored-by: Claude --- .../console/app/workflow_trigger.py | 8 +- .../console/workspace/trigger_providers.py | 28 +- api/controllers/trigger/trigger.py | 32 +- api/core/plugin/entities/request.py | 22 + api/core/plugin/impl/trigger.py | 243 ++++++- api/core/plugin/utils/http_parser.py | 182 +++++ api/core/trigger/entities/api_entities.py | 23 + api/core/trigger/entities/entities.py | 13 +- api/core/trigger/provider.py | 162 ++++- api/core/trigger/trigger_manager.py | 80 ++- api/fields/workflow_trigger_fields.py | 2 +- api/models/trigger.py | 16 +- .../trigger/trigger_provider_service.py | 30 +- ...trigger_subscription_validation_service.py | 48 ++ api/services/trigger_service.py | 187 ++++- .../core/plugin/utils/test_http_parser.py | 655 ++++++++++++++++++ 16 files changed, 1630 insertions(+), 101 deletions(-) create mode 100644 api/core/plugin/utils/http_parser.py create mode 100644 api/services/trigger/trigger_subscription_validation_service.py create mode 100644 api/tests/unit_tests/core/plugin/utils/test_http_parser.py diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 099dbea02f..2a099935a8 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -19,10 +19,6 @@ from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger logger = logging.getLogger(__name__) - - - - class WebhookTriggerApi(Resource): """Webhook Trigger API""" @@ -87,7 +83,7 @@ class WebhookTriggerApi(Resource): base_url = dify_config.SERVICE_API_URL webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" - + return webhook_trigger @setup_required @@ -231,7 +227,7 @@ class AppTriggerEnableApi(Resource): trigger.icon = url_prefix + trigger.provider_name + "/icon" else: trigger.icon = "" - + return trigger diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 6fab876c1c..aced8b75f3 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -31,7 +31,7 @@ class TriggerProviderListApi(Resource): return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id)) -class TriggerProviderSubscriptionListApi(Resource): +class TriggerSubscriptionListApi(Resource): @setup_required @login_required @account_initialization_required @@ -54,7 +54,7 @@ class TriggerProviderSubscriptionListApi(Resource): raise -class TriggerProviderSubscriptionsAddApi(Resource): +class TriggerSubscriptionsAddApi(Resource): @setup_required @login_required @account_initialization_required @@ -99,7 +99,7 @@ class TriggerProviderSubscriptionsAddApi(Resource): raise -class TriggerProviderSubscriptionsDeleteApi(Resource): +class TriggerSubscriptionsDeleteApi(Resource): @setup_required @login_required @account_initialization_required @@ -125,7 +125,7 @@ class TriggerProviderSubscriptionsDeleteApi(Resource): raise -class TriggerProviderOAuthAuthorizeApi(Resource): +class TriggerOAuthAuthorizeApi(Resource): @setup_required @login_required @account_initialization_required @@ -189,7 +189,7 @@ class TriggerProviderOAuthAuthorizeApi(Resource): raise -class TriggerProviderOAuthCallbackApi(Resource): +class TriggerOAuthCallbackApi(Resource): @setup_required def get(self, provider): """Handle OAuth callback for trigger provider""" @@ -252,7 +252,7 @@ class TriggerProviderOAuthCallbackApi(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -class TriggerProviderOAuthRefreshTokenApi(Resource): +class TriggerOAuthRefreshTokenApi(Resource): @setup_required @login_required @account_initialization_required @@ -278,7 +278,7 @@ class TriggerProviderOAuthRefreshTokenApi(Resource): raise -class TriggerProviderOAuthClientManageApi(Resource): +class TriggerOAuthClientManageApi(Resource): @setup_required @login_required @account_initialization_required @@ -381,25 +381,25 @@ class TriggerProviderOAuthClientManageApi(Resource): # Trigger provider endpoints api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers") api.add_resource( - TriggerProviderSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions//list" + TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions//list" ) api.add_resource( - TriggerProviderSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions//add" + TriggerSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions//add" ) api.add_resource( - TriggerProviderSubscriptionsDeleteApi, + TriggerSubscriptionsDeleteApi, "/workspaces/current/trigger-provider/subscriptions//delete", ) # OAuth api.add_resource( - TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider//oauth/authorize" + TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider//oauth/authorize" ) -api.add_resource(TriggerProviderOAuthCallbackApi, "/oauth/plugin//trigger/callback") +api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin//trigger/callback") api.add_resource( - TriggerProviderOAuthRefreshTokenApi, + TriggerOAuthRefreshTokenApi, "/workspaces/current/trigger-provider/subscriptions//oauth/refresh", ) api.add_resource( - TriggerProviderOAuthClientManageApi, "/workspaces/current/trigger-provider//oauth/client" + TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider//oauth/client" ) diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py index 13f4a7e234..118e99fad4 100644 --- a/api/controllers/trigger/trigger.py +++ b/api/controllers/trigger/trigger.py @@ -1,21 +1,45 @@ import logging +import re from flask import jsonify, request from werkzeug.exceptions import NotFound from controllers.trigger import bp +from services.trigger.trigger_provider_service import TriggerProviderService +from services.trigger.trigger_subscription_validation_service import TriggerSubscriptionValidationService from services.trigger_service import TriggerService logger = logging.getLogger(__name__) +UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$" +UUID_MATCHER = re.compile(UUID_PATTERN) -@bp.route("/trigger/webhook/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) -def trigger_webhook(endpoint_id: str): + +@bp.route( + "/trigger/endpoint/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] +) +@bp.route( + "/trigger/endpoint-debug/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] +) +def trigger_endpoint(endpoint_id: str): """ - Handle webhook trigger calls. + Handle endpoint trigger calls. """ + # endpoint_id must be UUID + if not UUID_MATCHER.match(endpoint_id): + raise NotFound("Invalid endpoint ID") + handling_chain = [ + TriggerService.process_endpoint, + TriggerSubscriptionValidationService.process_validating_endpoint, + ] try: - return TriggerService.process_webhook(endpoint_id, request) + for handler in handling_chain: + response = handler(endpoint_id, request) + if response: + break + if not response: + raise NotFound("Endpoint not found") + return response except ValueError as e: raise NotFound(str(e)) except Exception as e: diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 3a783dad3e..3a573ef472 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,5 +1,6 @@ from typing import Any, Literal, Optional +from flask import Response from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.provider_entities import BasicProviderConfig @@ -237,3 +238,24 @@ class RequestFetchAppInfo(BaseModel): """ app_id: str + + +class TriggerInvokeResponse(BaseModel): + event: dict[str, Any] + + +class PluginTriggerDispatchResponse(BaseModel): + triggers: list[str] + raw_http_response: str + +class TriggerSubscriptionResponse(BaseModel): + subscription: dict[str, Any] + +class TriggerValidateProviderCredentialsResponse(BaseModel): + valid: bool + message: str + error: str + +class TriggerDispatchResponse: + triggers: list[str] + response: Response diff --git a/api/core/plugin/impl/trigger.py b/api/core/plugin/impl/trigger.py index d048442e56..d56db80588 100644 --- a/api/core/plugin/impl/trigger.py +++ b/api/core/plugin/impl/trigger.py @@ -1,14 +1,26 @@ +import binascii from typing import Any -from core.plugin.entities.plugin import TriggerProviderID -from core.plugin.entities.plugin_daemon import PluginTriggerProviderEntity +from flask import Request + +from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID +from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity +from core.plugin.entities.request import ( + PluginTriggerDispatchResponse, + TriggerDispatchResponse, + TriggerInvokeResponse, + TriggerSubscriptionResponse, + TriggerValidateProviderCredentialsResponse, +) from core.plugin.impl.base import BasePluginClient +from core.plugin.utils.http_parser import deserialize_response, serialize_request +from core.trigger.entities.entities import Subscription class PluginTriggerManager(BasePluginClient): def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]: """ - Fetch tool providers for the given tenant. + Fetch trigger providers for the given tenant. """ def transformer(json_response: dict[str, Any]) -> dict: @@ -31,7 +43,7 @@ class PluginTriggerManager(BasePluginClient): for provider in response: provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name + # override the provider name for each trigger to plugin_id/provider_name for trigger in provider.declaration.triggers: trigger.identity.provider = provider.declaration.identity.name @@ -39,7 +51,7 @@ class PluginTriggerManager(BasePluginClient): def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity: """ - Fetch tool provider for the given tenant and plugin. + Fetch trigger provider for the given tenant and plugin. """ def transformer(json_response: dict[str, Any]) -> dict: @@ -65,3 +77,224 @@ class PluginTriggerManager(BasePluginClient): trigger.identity.provider = response.declaration.identity.name return response + + def invoke_trigger( + self, + tenant_id: str, + user_id: str, + provider: str, + trigger: str, + credentials: dict[str, Any], + credential_type: CredentialType, + request: Request, + parameters: dict[str, Any], + ) -> TriggerInvokeResponse: + """ + Invoke a trigger with the given parameters. + """ + trigger_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/trigger/invoke", + TriggerInvokeResponse, + data={ + "user_id": user_id, + "data": { + "provider": trigger_provider_id.provider_name, + "trigger": trigger, + "credentials": credentials, + "credential_type": credential_type, + "raw_http_request": binascii.hexlify(serialize_request(request)).decode(), + "parameters": parameters, + }, + }, + headers={ + "X-Plugin-ID": trigger_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return TriggerInvokeResponse(event=resp.event) + + raise ValueError("No response received from plugin daemon for invoke trigger") + + def validate_provider_credentials( + self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + ) -> TriggerValidateProviderCredentialsResponse: + """ + Validate the credentials of the trigger provider. + """ + trigger_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/trigger/validate_credentials", + TriggerValidateProviderCredentialsResponse, + data={ + "user_id": user_id, + "data": { + "provider": trigger_provider_id.provider_name, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": trigger_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + return TriggerValidateProviderCredentialsResponse(valid=False, message="No response", error="No response") + + def dispatch_event( + self, + tenant_id: str, + user_id: str, + provider: str, + subscription: dict[str, Any], + request: Request, + ) -> TriggerDispatchResponse: + """ + Dispatch an event to triggers. + """ + trigger_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/trigger/dispatch_event", + PluginTriggerDispatchResponse, + data={ + "user_id": user_id, + "data": { + "provider": trigger_provider_id.provider_name, + "subscription": subscription, + "raw_http_request": binascii.hexlify(serialize_request(request)).decode(), + }, + }, + headers={ + "X-Plugin-ID": trigger_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return TriggerDispatchResponse( + triggers=resp.triggers, + response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())), + ) + + raise ValueError("No response received from plugin daemon for dispatch event") + + def subscribe( + self, + tenant_id: str, + user_id: str, + provider: str, + credentials: dict[str, Any], + endpoint: str, + parameters: dict[str, Any], + ) -> TriggerSubscriptionResponse: + """ + Subscribe to a trigger. + """ + trigger_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/trigger/subscribe", + TriggerSubscriptionResponse, + data={ + "user_id": user_id, + "data": { + "provider": trigger_provider_id.provider_name, + "credentials": credentials, + "endpoint": endpoint, + "parameters": parameters, + }, + }, + headers={ + "X-Plugin-ID": trigger_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("No response received from plugin daemon for subscribe") + + def unsubscribe( + self, + tenant_id: str, + user_id: str, + provider: str, + subscription: Subscription, + credentials: dict[str, Any], + ) -> TriggerSubscriptionResponse: + """ + Unsubscribe from a trigger. + """ + trigger_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/trigger/unsubscribe", + TriggerSubscriptionResponse, + data={ + "user_id": user_id, + "data": { + "provider": trigger_provider_id.provider_name, + "subscription": subscription.model_dump(), + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": trigger_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("No response received from plugin daemon for unsubscribe") + + def refresh( + self, + tenant_id: str, + user_id: str, + provider: str, + subscription: Subscription, + credentials: dict[str, Any], + ) -> TriggerSubscriptionResponse: + """ + Refresh a trigger subscription. + """ + trigger_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/trigger/refresh", + TriggerSubscriptionResponse, + data={ + "user_id": user_id, + "data": { + "provider": trigger_provider_id.provider_name, + "subscription": subscription.model_dump(), + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": trigger_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("No response received from plugin daemon for refresh") diff --git a/api/core/plugin/utils/http_parser.py b/api/core/plugin/utils/http_parser.py new file mode 100644 index 0000000000..b4d59a58a2 --- /dev/null +++ b/api/core/plugin/utils/http_parser.py @@ -0,0 +1,182 @@ +from io import BytesIO +from typing import Any + +from flask import Request, Response +from werkzeug.datastructures import Headers + + +def serialize_request(request: Request) -> bytes: + """ + Convert a Request object to raw HTTP data. + + Args: + request: The Request object to convert. + + Returns: + The raw HTTP data as bytes. + """ + # Start with the request line + method = request.method + path = request.full_path + # Remove trailing ? if there's no query string + path = path.removesuffix("?") + protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") + raw_data = f"{method} {path} {protocol}\r\n".encode() + + # Add headers + for header_name, header_value in request.headers.items(): + raw_data += f"{header_name}: {header_value}\r\n".encode() + + # Add empty line to separate headers from body + raw_data += b"\r\n" + + # Add body if exists + body = request.get_data(as_text=False) + if body: + raw_data += body + + return raw_data + + +def deserialize_request(raw_data: bytes) -> Request: + """ + Convert raw HTTP data to a Request object. + + Args: + raw_data: The raw HTTP data as bytes. + + Returns: + A Flask Request object. + """ + lines = raw_data.split(b"\r\n") + + # Parse request line + request_line = lines[0].decode("utf-8") + parts = request_line.split(" ", 2) # Split into max 3 parts + if len(parts) < 2: + raise ValueError(f"Invalid request line: {request_line}") + + method = parts[0] + path = parts[1] + protocol = parts[2] if len(parts) > 2 else "HTTP/1.1" + + # Parse headers + headers = Headers() + body_start = 0 + for i in range(1, len(lines)): + line = lines[i] + if line == b"": + body_start = i + 1 + break + if b":" in line: + header_line = line.decode("utf-8") + name, value = header_line.split(":", 1) + headers[name.strip()] = value.strip() + + # Extract body + body = b"" + if body_start > 0 and body_start < len(lines): + body = b"\r\n".join(lines[body_start:]) + + # Create environ for Request + environ = { + "REQUEST_METHOD": method, + "PATH_INFO": path.split("?")[0] if "?" in path else path, + "QUERY_STRING": path.split("?")[1] if "?" in path else "", + "SERVER_NAME": headers.get("Host", "localhost").split(":")[0], + "SERVER_PORT": headers.get("Host", "localhost:80").split(":")[1] if ":" in headers.get("Host", "") else "80", + "SERVER_PROTOCOL": protocol, + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "https" if headers.get("X-Forwarded-Proto") == "https" else "http", + "CONTENT_LENGTH": str(len(body)) if body else "0", + "CONTENT_TYPE": headers.get("Content-Type", ""), + } + + # Add headers to environ + for header_name, header_value in headers.items(): + env_name = f"HTTP_{header_name.upper().replace('-', '_')}" + if header_name.upper() not in ["CONTENT-TYPE", "CONTENT-LENGTH"]: + environ[env_name] = header_value + + return Request(environ) + + +def serialize_response(response: Response) -> bytes: + """ + Convert a Response object to raw HTTP data. + + Args: + response: The Response object to convert. + + Returns: + The raw HTTP data as bytes. + """ + # Start with the status line + protocol = "HTTP/1.1" + status_code = response.status_code + status_text = response.status.split(" ", 1)[1] if " " in response.status else "OK" + raw_data = f"{protocol} {status_code} {status_text}\r\n".encode() + + # Add headers + for header_name, header_value in response.headers.items(): + raw_data += f"{header_name}: {header_value}\r\n".encode() + + # Add empty line to separate headers from body + raw_data += b"\r\n" + + # Add body if exists + body = response.get_data(as_text=False) + if body: + raw_data += body + + return raw_data + + +def deserialize_response(raw_data: bytes) -> Response: + """ + Convert raw HTTP data to a Response object. + + Args: + raw_data: The raw HTTP data as bytes. + + Returns: + A Flask Response object. + """ + lines = raw_data.split(b"\r\n") + + # Parse status line + status_line = lines[0].decode("utf-8") + parts = status_line.split(" ", 2) + if len(parts) < 2: + raise ValueError(f"Invalid status line: {status_line}") + + protocol = parts[0] + status_code = int(parts[1]) + status_text = parts[2] if len(parts) > 2 else "OK" + + # Parse headers + headers: dict[str, Any] = {} + body_start = 0 + for i in range(1, len(lines)): + line = lines[i] + if line == b"": + body_start = i + 1 + break + if b":" in line: + header_line = line.decode("utf-8") + name, value = header_line.split(":", 1) + headers[name.strip()] = value.strip() + + # Extract body + body = b"" + if body_start > 0 and body_start < len(lines): + body = b"\r\n".join(lines[body_start:]) + + # Create Response object + response = Response( + response=body, + status=status_code, + headers=headers, + ) + + return response diff --git a/api/core/trigger/entities/api_entities.py b/api/core/trigger/entities/api_entities.py index 54f297b4b5..20687cfeb8 100644 --- a/api/core/trigger/entities/api_entities.py +++ b/api/core/trigger/entities/api_entities.py @@ -7,6 +7,7 @@ from core.entities.provider_entities import ProviderConfig from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import ( OAuthSchema, + Subscription, SubscriptionSchema, TriggerDescription, TriggerEntity, @@ -39,5 +40,27 @@ class TriggerApiEntity(BaseModel): parameters: list[TriggerParameter] = Field(description="The parameters of the trigger") output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger") +class SubscriptionValidation(BaseModel): + id: str + name: str + tenant_id: str + user_id: str + provider_id: str + endpoint: str + parameters: dict + properties: dict + credentials: dict + credential_type: str + credential_expires_at: int + expires_at: int + + def to_subscription(self) -> Subscription: + return Subscription( + expires_at=self.expires_at, + endpoint=self.endpoint, + parameters=self.parameters, + properties=self.properties, + ) + __all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"] diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py index b0d8109f75..99163e8c6f 100644 --- a/api/core/trigger/entities/entities.py +++ b/api/core/trigger/entities/entities.py @@ -143,12 +143,19 @@ class Subscription(BaseModel): Contains all information needed to manage the subscription lifecycle. """ - expire_at: int = Field( + expires_at: int = Field( ..., description="The timestamp when the subscription will expire, this for refresh the subscription" ) - metadata: dict[str, Any] = Field( - ..., description="Metadata about the subscription in the external service, defined in subscription_schema" + endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events") + + parameters: dict[str, Any] | None = Field( + default=None, + description="""The parameters of the subscription, this is the creation parameters. + Only available when creating a new subscription by credentials(auto subscription), not manual subscription""", + ) + properties: dict[str, Any] = Field( + ..., description="Subscription data containing all properties and provider-specific information" ) diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py index f57bb7514d..240e65b390 100644 --- a/api/core/trigger/provider.py +++ b/api/core/trigger/provider.py @@ -3,12 +3,19 @@ Trigger Provider Controller for managing trigger providers """ import logging -import time -from typing import Optional +from typing import Any, Optional + +from flask import Request from core.entities.provider_entities import BasicProviderConfig from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import ( + TriggerDispatchResponse, + TriggerInvokeResponse, + TriggerValidateProviderCredentialsResponse, +) +from core.plugin.impl.trigger import PluginTriggerManager from core.trigger.entities.api_entities import TriggerProviderApiEntity from core.trigger.entities.entities import ( ProviderConfig, @@ -90,18 +97,36 @@ class PluginTriggerProviderController: :return: List of subscription config schemas """ - return self.entity.subscription_schema + # Return the parameters schema from the subscription schema + if self.entity.subscription_schema and self.entity.subscription_schema.parameters_schema: + return self.entity.subscription_schema.parameters_schema + return [] - def validate_credentials(self, credentials: dict) -> None: + def validate_credentials(self, credentials: dict) -> TriggerValidateProviderCredentialsResponse: """ Validate credentials against schema :param credentials: Credentials to validate - :raises ValueError: If credentials are invalid + :return: Validation response """ + # First validate against schema for config in self.entity.credentials_schema: if config.required and config.name not in credentials: - raise ValueError(f"Missing required credential field: {config.name}") + return TriggerValidateProviderCredentialsResponse( + valid=False, + message=f"Missing required credential field: {config.name}", + error=f"Missing required credential field: {config.name}", + ) + + # Then validate with the plugin daemon + manager = PluginTriggerManager() + provider_id = self.get_provider_id() + return manager.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id="system", # System validation + provider=str(provider_id), + credentials=credentials, + ) def get_supported_credential_types(self) -> list[CredentialType]: """ @@ -143,58 +168,127 @@ class PluginTriggerProviderController: """ return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] - @property - def need_credentials(self) -> bool: - """Check if this provider needs credentials""" - return len(self.get_supported_credential_types()) > 0 + def dispatch(self,user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse: + """ + Dispatch a trigger through plugin runtime - def execute_trigger(self, trigger_name: str, parameters: dict, credentials: dict) -> dict: + :param user_id: User ID + :param request: Flask request object + :param subscription: Subscription + :return: Dispatch response with triggers and raw HTTP response + """ + manager = PluginTriggerManager() + provider_id = self.get_provider_id() + + response = manager.dispatch_event( + tenant_id=self.tenant_id, + user_id=user_id, + provider=str(provider_id), + subscription=subscription.model_dump(), + request=request, + ) + return response + + def invoke_trigger( + self, + user_id: str, + trigger_name: str, + parameters: dict, + credentials: dict, + credential_type: CredentialType, + request: Request, + ) -> TriggerInvokeResponse: """ Execute a trigger through plugin runtime + :param user_id: User ID :param trigger_name: Trigger name :param parameters: Trigger parameters :param credentials: Provider credentials - :return: Execution result + :param credential_type: Credential type + :param request: Request + :return: Trigger execution result """ - logger.info("Executing trigger %s for plugin %s", trigger_name, self.plugin_id) - return { - "success": True, - "trigger": trigger_name, - "plugin": self.plugin_id, - "result": "Trigger executed successfully", - } + manager = PluginTriggerManager() + provider_id = self.get_provider_id() - def subscribe_trigger(self, trigger_name: str, subscription_params: dict, credentials: dict) -> Subscription: + return manager.invoke_trigger( + tenant_id=self.tenant_id, + user_id=user_id, + provider=str(provider_id), + trigger=trigger_name, + credentials=credentials, + credential_type=credential_type, + request=request, + parameters=parameters, + ) + + def subscribe_trigger(self, user_id: str, endpoint: str, parameters: dict, credentials: dict) -> Subscription: """ Subscribe to a trigger through plugin runtime - :param trigger_name: Trigger name + :param user_id: User ID + :param endpoint: Subscription endpoint :param subscription_params: Subscription parameters :param credentials: Provider credentials :return: Subscription result """ - logger.info("Subscribing to trigger %s for plugin %s", trigger_name, self.plugin_id) - return Subscription( - expire_at=int(time.time()) + 86400, # 24 hours from now - metadata={ - "subscription_id": f"{self.plugin_id}_{trigger_name}_{time.time()}", - "webhook_url": f"/triggers/webhook/{self.plugin_id}/{trigger_name}", - **subscription_params, - }, + manager = PluginTriggerManager() + provider_id = self.get_provider_id() + + response = manager.subscribe( + tenant_id=self.tenant_id, + user_id=user_id, + provider=str(provider_id), + credentials=credentials, + endpoint=endpoint, + parameters=parameters, ) - def unsubscribe_trigger(self, trigger_name: str, subscription_metadata: dict, credentials: dict) -> Unsubscription: + return Subscription.model_validate(response.subscription) + + def unsubscribe_trigger(self, user_id: str, subscription: Subscription, credentials: dict) -> Unsubscription: """ Unsubscribe from a trigger through plugin runtime - :param trigger_name: Trigger name - :param subscription_metadata: Subscription metadata + :param user_id: User ID + :param subscription: Subscription metadata :param credentials: Provider credentials :return: Unsubscription result """ - logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id) - return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}") + manager = PluginTriggerManager() + provider_id = self.get_provider_id() + + response = manager.unsubscribe( + tenant_id=self.tenant_id, + user_id=user_id, + provider=str(provider_id), + subscription=subscription, + credentials=credentials, + ) + + return Unsubscription.model_validate(response.subscription) + + def refresh_trigger(self, subscription: Subscription, credentials: dict) -> Subscription: + """ + Refresh a trigger subscription through plugin runtime + + :param subscription: Subscription metadata + :param credentials: Provider credentials + :return: Refreshed subscription result + """ + manager = PluginTriggerManager() + provider_id = self.get_provider_id() + + response = manager.refresh( + tenant_id=self.tenant_id, + user_id="system", # System refresh + provider=str(provider_id), + subscription=subscription, + credentials=credentials, + ) + + return Subscription.model_validate(response.subscription) __all__ = ["PluginTriggerProviderController"] diff --git a/api/core/trigger/trigger_manager.py b/api/core/trigger/trigger_manager.py index 58ae472d1b..b39557a555 100644 --- a/api/core/trigger/trigger_manager.py +++ b/api/core/trigger/trigger_manager.py @@ -5,7 +5,10 @@ Trigger Manager for loading and managing trigger providers and triggers import logging from typing import Optional +from flask import Request + from core.plugin.entities.plugin import TriggerProviderID +from core.plugin.entities.request import TriggerInvokeResponse from core.plugin.impl.trigger import PluginTriggerManager from core.trigger.entities.entities import ( ProviderConfig, @@ -14,6 +17,7 @@ from core.trigger.entities.entities import ( Unsubscription, ) from core.trigger.provider import PluginTriggerProviderController +from core.plugin.entities.plugin_daemon import CredentialType logger = logging.getLogger(__name__) @@ -123,75 +127,90 @@ class TriggerManager: :return: Tuple of (is_valid, error_message) """ try: - cls.get_trigger_provider(tenant_id, provider_id).validate_credentials(credentials) - return True, "" + provider = cls.get_trigger_provider(tenant_id, provider_id) + validation_result = provider.validate_credentials(credentials) + return validation_result.valid, validation_result.message if not validation_result.valid else "" except Exception as e: return False, str(e) @classmethod - def execute_trigger( - cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str, parameters: dict, credentials: dict - ) -> dict: + def invoke_trigger( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + trigger_name: str, + parameters: dict, + credentials: dict, + credential_type: CredentialType, + request: Request, + ) -> TriggerInvokeResponse: """ Execute a trigger :param tenant_id: Tenant ID + :param user_id: User ID :param provider_id: Provider ID :param trigger_name: Trigger name :param parameters: Trigger parameters :param credentials: Provider credentials + :param credential_type: Credential type + :param request: Request :return: Trigger execution result """ - trigger = cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name) + provider = cls.get_trigger_provider(tenant_id, provider_id) + trigger = provider.get_trigger(trigger_name) if not trigger: raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}") - return cls.get_trigger_provider(tenant_id, provider_id).execute_trigger(trigger_name, parameters, credentials) + return provider.invoke_trigger(user_id, trigger_name, parameters, credentials, credential_type, request) @classmethod def subscribe_trigger( cls, tenant_id: str, + user_id: str, provider_id: TriggerProviderID, - trigger_name: str, - subscription_params: dict, + endpoint: str, + parameters: dict, credentials: dict, ) -> Subscription: """ Subscribe to a trigger (e.g., register webhook) :param tenant_id: Tenant ID + :param user_id: User ID :param provider_id: Provider ID - :param trigger_name: Trigger name - :param subscription_params: Subscription parameters + :param endpoint: Subscription endpoint + :param parameters: Subscription parameters :param credentials: Provider credentials :return: Subscription result """ - return cls.get_trigger_provider(tenant_id, provider_id).subscribe_trigger( - trigger_name, subscription_params, credentials + provider = cls.get_trigger_provider(tenant_id, provider_id) + return provider.subscribe_trigger( + user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials ) @classmethod def unsubscribe_trigger( cls, tenant_id: str, + user_id: str, provider_id: TriggerProviderID, - trigger_name: str, - subscription_metadata: dict, + subscription: Subscription, credentials: dict, ) -> Unsubscription: """ Unsubscribe from a trigger :param tenant_id: Tenant ID + :param user_id: User ID :param provider_id: Provider ID - :param trigger_name: Trigger name - :param subscription_metadata: Subscription metadata from subscribe operation + :param subscription: Subscription metadata from subscribe operation :param credentials: Provider credentials :return: Unsubscription result """ - return cls.get_trigger_provider(tenant_id, provider_id).unsubscribe_trigger( - trigger_name, subscription_metadata, credentials - ) + provider = cls.get_trigger_provider(tenant_id, provider_id) + return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials) @classmethod def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]: @@ -204,6 +223,27 @@ class TriggerManager: """ return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema() + @classmethod + def refresh_trigger( + cls, + tenant_id: str, + provider_id: TriggerProviderID, + trigger_name: str, + subscription: Subscription, + credentials: dict, + ) -> Subscription: + """ + Refresh a trigger subscription + + :param tenant_id: Tenant ID + :param provider_id: Provider ID + :param trigger_name: Trigger name + :param subscription: Subscription metadata from subscribe operation + :param credentials: Provider credentials + :return: Refreshed subscription result + """ + return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(trigger_name, subscription, credentials) + # Export __all__ = ["TriggerManager"] diff --git a/api/fields/workflow_trigger_fields.py b/api/fields/workflow_trigger_fields.py index c6d254320f..702d20b3ce 100644 --- a/api/fields/workflow_trigger_fields.py +++ b/api/fields/workflow_trigger_fields.py @@ -23,4 +23,4 @@ webhook_trigger_fields = { "node_id": fields.String, "triggered_by": fields.String, "created_at": fields.DateTime(dt_format="iso8601"), -} \ No newline at end of file +} diff --git a/api/models/trigger.py b/api/models/trigger.py index 25c08f2ea3..f65bd5b470 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -8,7 +8,8 @@ from sqlalchemy import DateTime, Index, Integer, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column from core.plugin.entities.plugin_daemon import CredentialType -from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity +from core.trigger.entities.api_entities import SubscriptionValidation, TriggerProviderSubscriptionApiEntity +from core.trigger.entities.entities import Subscription from models.base import Base from models.types import StringUUID @@ -24,6 +25,7 @@ class TriggerSubscription(Base): sa.PrimaryKeyConstraint("id", name="trigger_subscription_pkey"), Index("idx_trigger_subscriptions_tenant_provider", "tenant_id", "provider_id"), UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_subscription"), + UniqueConstraint("endpoint", name="unique_trigger_subscription_endpoint"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) @@ -33,8 +35,9 @@ class TriggerSubscription(Base): provider_id: Mapped[str] = mapped_column( String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)" ) + endpoint: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint") parameters: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON") - configuration: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription configuration JSON") + properties: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON") credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription credentials JSON") credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") @@ -57,6 +60,14 @@ class TriggerSubscription(Base): # Check if token expires in next 3 minutes return (self.credential_expires_at - 180) < int(time.time()) + def to_entity(self) -> Subscription: + return Subscription( + expires_at=self.expires_at, + endpoint=self.endpoint, + parameters=self.parameters, + properties=self.properties, + ) + def to_api_entity(self) -> TriggerProviderSubscriptionApiEntity: return TriggerProviderSubscriptionApiEntity( id=self.id, @@ -66,7 +77,6 @@ class TriggerSubscription(Base): credentials=self.credentials, ) - # system level trigger oauth client params class TriggerOAuthSystemClient(Base): __tablename__ = "trigger_oauth_system_clients" diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 7b4cd5852d..895a19097b 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -4,6 +4,7 @@ import re from collections.abc import Mapping from typing import Any, Optional +from flask import Request, Response from sqlalchemy import desc from sqlalchemy.orm import Session @@ -15,7 +16,11 @@ from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params -from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity +from core.trigger.entities.api_entities import ( + SubscriptionValidation, + TriggerProviderApiEntity, + TriggerProviderSubscriptionApiEntity, +) from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import ( create_trigger_provider_encrypter_for_subscription, @@ -32,7 +37,7 @@ logger = logging.getLogger(__name__) class TriggerProviderService: """Service for managing trigger providers and credentials""" - __MAX_TRIGGER_PROVIDER_COUNT__ = 100 + __MAX_TRIGGER_PROVIDER_COUNT__ = 10 @classmethod def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]: @@ -553,3 +558,24 @@ class TriggerProviderService: except Exception as e: logger.warning("Error generating provider name") return f"{credential_type.get_name()} 1" + + @classmethod + def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None: + """ + Get a trigger subscription by the endpoint ID. + """ + with Session(db.engine, autoflush=False) as session: + subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first() + return subscription + + @classmethod + def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None: + """ + Get a trigger subscription by the endpoint ID. + """ + cache_key = f"trigger:subscription:validation:endpoint:{endpoint_id}" + subscription_cache = redis_client.get(cache_key) + if subscription_cache: + return SubscriptionValidation.model_validate(json.loads(subscription_cache)) + + return None \ No newline at end of file diff --git a/api/services/trigger/trigger_subscription_validation_service.py b/api/services/trigger/trigger_subscription_validation_service.py new file mode 100644 index 0000000000..0a2b8ce1e3 --- /dev/null +++ b/api/services/trigger/trigger_subscription_validation_service.py @@ -0,0 +1,48 @@ +import logging + +from flask import Request, Response + +from core.plugin.entities.plugin import TriggerProviderID +from core.trigger.trigger_manager import TriggerManager +from services.trigger.trigger_provider_service import TriggerProviderService + +logger = logging.getLogger(__name__) + + +class TriggerSubscriptionValidationService: + __VALIDATION_REQUEST_CACHE_COUNT__ = 10 + __VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000 + + @classmethod + def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None: + """ + Append the validation request log to Redis. + """ + + + @classmethod + def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: + """ + Process a temporary endpoint request. + + :param endpoint_id: The endpoint identifier + :param request: The Flask request object + :return: The Flask response object + """ + # check if validation endpoint exists + subscription_validation = TriggerProviderService.get_subscription_validation(endpoint_id) + if not subscription_validation: + return None + + # response to validation endpoint + controller = TriggerManager.get_trigger_provider( + subscription_validation.tenant_id, TriggerProviderID(subscription_validation.provider_id) + ) + response = controller.dispatch( + user_id=subscription_validation.user_id, + request=request, + subscription=subscription_validation.to_subscription(), + ) + # append the request log + cls.append_validation_request_log(endpoint_id, request, response.response) + return response.response diff --git a/api/services/trigger_service.py b/api/services/trigger_service.py index fa56e19773..249e9fe33b 100644 --- a/api/services/trigger_service.py +++ b/api/services/trigger_service.py @@ -1,23 +1,192 @@ +import json import logging +import time +import uuid +from typing import Any, Optional from flask import Request, Response +from core.plugin.entities.plugin import TriggerProviderID +from core.trigger.trigger_manager import TriggerManager +from extensions.ext_redis import redis_client +from services.trigger.trigger_provider_service import TriggerProviderService + logger = logging.getLogger(__name__) class TriggerService: - __MAX_REQUEST_LOG_COUNT__ = 10 + __TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000 + __ENDPOINT_REQUEST_CACHE_COUNT__ = 10 + __ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000 + # Lua script for atomic write with time & count based cleanup + __LUA_SCRIPT__ = """ + -- KEYS[1] = zset key + -- ARGV[1] = max_count (maximum number of entries to keep) + -- ARGV[2] = min_ts_ms (minimum timestamp to keep = now_ms - ttl_ms) + -- ARGV[3] = now_ms (current timestamp in milliseconds) + -- ARGV[4] = member (log entry JSON) + + local key = KEYS[1] + local maxCount = tonumber(ARGV[1]) + local minTs = tonumber(ARGV[2]) + local nowMs = tonumber(ARGV[3]) + local member = ARGV[4] + + -- 1) Add new entry with timestamp as score + redis.call('ZADD', key, nowMs, member) + + -- 2) Remove entries older than minTs (time-based cleanup) + redis.call('ZREMRANGEBYSCORE', key, '-inf', minTs) + + -- 3) Remove oldest entries if count exceeds maxCount (count-based cleanup) + local n = redis.call('ZCARD', key) + if n > maxCount then + redis.call('ZREMRANGEBYRANK', key, 0, n - maxCount - 1) -- 0 is oldest + end + + return n + """ @classmethod - def process_webhook(cls, webhook_id: str, request: Request) -> Response: - """Extract and process data from incoming webhook request.""" - # TODO redis slidingwindow log, save the recent request log in redis, rollover the log when the window is full + def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: + """Extract and process data from incoming endpoint request.""" + subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id) + if not subscription: + return None + + provider_id = TriggerProviderID(subscription.provider_id) + controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id) + if not controller: + return None + + dispatch_response = controller.dispatch( + user_id=subscription.user_id, request=request, subscription=subscription.to_entity() + ) - # TODO find the trigger subscription + # TODO invoke triggers + # dispatch_response.triggers - # TODO fetch the trigger controller + return dispatch_response.response - # TODO dispatch by the trigger controller + @classmethod + def log_endpoint_request(cls, endpoint_id: str, request: Request) -> int: + """ + Log the endpoint request to Redis using ZSET for rolling log with time & count based retention. - # TODO using the dispatch result(events) to invoke the trigger events - raise NotImplementedError("Not implemented") + Args: + endpoint_id: The endpoint identifier + request: The Flask request object + + Returns: + The current number of logged requests for this endpoint + """ + try: + # Prepare timestamp + now_ms = int(time.time() * 1000) + min_ts = now_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ + + # Extract request data + request_data = { + "id": str(uuid.uuid4()), + "timestamp": now_ms, + "method": request.method, + "path": request.path, + "headers": dict(request.headers), + "query_params": request.args.to_dict(flat=False) if request.args else {}, + "body": None, + "remote_addr": request.remote_addr, + } + + # Try to get request body if it exists + if request.is_json: + try: + request_data["body"] = request.get_json(force=True) + except Exception: + request_data["body"] = request.get_data(as_text=True) + elif request.data: + request_data["body"] = request.get_data(as_text=True) + + # Serialize to JSON + member = json.dumps(request_data, separators=(",", ":")) + + # Execute Lua script atomically + key = f"trigger:endpoint_requests:{endpoint_id}" + count = redis_client.eval( + cls.__LUA_SCRIPT__, + 1, # number of keys + key, # KEYS[1] + str(cls.__ENDPOINT_REQUEST_CACHE_COUNT__), # ARGV[1] - max count + str(min_ts), # ARGV[2] - minimum timestamp + str(now_ms), # ARGV[3] - current timestamp + member, # ARGV[4] - log entry + ) + + logger.debug("Logged request for endpoint %s, current count: %s", endpoint_id, count) + return count + + except Exception as e: + logger.exception("Failed to log endpoint request for %s", endpoint_id, exc_info=e) + # Don't fail the main request processing if logging fails + return 0 + + @classmethod + def get_recent_endpoint_requests( + cls, endpoint_id: str, limit: int = 100, start_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None + ) -> list[dict[str, Any]]: + """ + Retrieve recent logged requests for an endpoint. + + Args: + endpoint_id: The endpoint identifier + limit: Maximum number of entries to return + start_time_ms: Start timestamp in milliseconds (optional) + end_time_ms: End timestamp in milliseconds (optional, defaults to now) + + Returns: + List of request log entries, newest first + """ + try: + key = f"trigger:endpoint_requests:{endpoint_id}" + + # Set time bounds + if end_time_ms is None: + end_time_ms = int(time.time() * 1000) + if start_time_ms is None: + start_time_ms = end_time_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ + + # Get entries in reverse order (newest first) + entries = redis_client.zrevrangebyscore(key, max=end_time_ms, min=start_time_ms, start=0, num=limit) + + # Parse JSON entries + requests = [] + for entry in entries: + try: + requests.append(json.loads(entry)) + except json.JSONDecodeError: + logger.warning("Failed to parse log entry: %s", entry) + + return requests + + except Exception as e: + logger.exception("Failed to retrieve endpoint requests for %s", endpoint_id, exc_info=e) + return [] + + @classmethod + def clear_endpoint_requests(cls, endpoint_id: str) -> bool: + """ + Clear all logged requests for an endpoint. + + Args: + endpoint_id: The endpoint identifier + + Returns: + True if successful, False otherwise + """ + try: + key = f"trigger:endpoint_requests:{endpoint_id}" + redis_client.delete(key) + logger.info("Cleared request logs for endpoint %s", endpoint_id) + return True + except Exception as e: + logger.exception("Failed to clear endpoint requests for %s", endpoint_id, exc_info=e) + return False diff --git a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py new file mode 100644 index 0000000000..934331e074 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py @@ -0,0 +1,655 @@ +import pytest +from flask import Request, Response + +from core.plugin.utils.http_parser import ( + deserialize_request, + deserialize_response, + serialize_request, + serialize_response, +) + + +class TestSerializeRequest: + def test_serialize_simple_get_request(self): + # Create a simple GET request + environ = { + "REQUEST_METHOD": "GET", + "PATH_INFO": "/api/test", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": None, + "wsgi.url_scheme": "http", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert raw_data.startswith(b"GET /api/test HTTP/1.1\r\n") + assert b"\r\n\r\n" in raw_data # Empty line between headers and body + + def test_serialize_request_with_query_params(self): + # Create a GET request with query parameters + environ = { + "REQUEST_METHOD": "GET", + "PATH_INFO": "/api/search", + "QUERY_STRING": "q=test&limit=10", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": None, + "wsgi.url_scheme": "http", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert raw_data.startswith(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\n") + + def test_serialize_post_request_with_body(self): + # Create a POST request with body + from io import BytesIO + + body = b'{"name": "test", "value": 123}' + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/api/data", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "CONTENT_TYPE": "application/json", + "HTTP_CONTENT_TYPE": "application/json", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert b"POST /api/data HTTP/1.1\r\n" in raw_data + assert b"Content-Type: application/json" in raw_data + assert raw_data.endswith(body) + + def test_serialize_request_with_custom_headers(self): + # Create a request with custom headers + environ = { + "REQUEST_METHOD": "GET", + "PATH_INFO": "/api/test", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": None, + "wsgi.url_scheme": "http", + "HTTP_AUTHORIZATION": "Bearer token123", + "HTTP_X_CUSTOM_HEADER": "custom-value", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert b"Authorization: Bearer token123" in raw_data + assert b"X-Custom-Header: custom-value" in raw_data + + +class TestDeserializeRequest: + def test_deserialize_simple_get_request(self): + raw_data = b"GET /api/test HTTP/1.1\r\nHost: localhost:8000\r\n\r\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.path == "/api/test" + assert request.headers.get("Host") == "localhost:8000" + + def test_deserialize_request_with_query_params(self): + raw_data = b"GET /api/search?q=test&limit=10 HTTP/1.1\r\nHost: example.com\r\n\r\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.path == "/api/search" + assert request.query_string == b"q=test&limit=10" + assert request.args.get("q") == "test" + assert request.args.get("limit") == "10" + + def test_deserialize_post_request_with_body(self): + body = b'{"name": "test", "value": 123}' + raw_data = ( + b"POST /api/data HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: " + str(len(body)).encode() + b"\r\n" + b"\r\n" + body + ) + + request = deserialize_request(raw_data) + + assert request.method == "POST" + assert request.path == "/api/data" + assert request.content_type == "application/json" + assert request.get_data() == body + + def test_deserialize_request_with_custom_headers(self): + raw_data = ( + b"GET /api/protected HTTP/1.1\r\n" + b"Host: api.example.com\r\n" + b"Authorization: Bearer token123\r\n" + b"X-Custom-Header: custom-value\r\n" + b"User-Agent: TestClient/1.0\r\n" + b"\r\n" + ) + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.headers.get("Authorization") == "Bearer token123" + assert request.headers.get("X-Custom-Header") == "custom-value" + assert request.headers.get("User-Agent") == "TestClient/1.0" + + def test_deserialize_request_with_multiline_body(self): + body = b"line1\r\nline2\r\nline3" + raw_data = b"PUT /api/text HTTP/1.1\r\nHost: localhost\r\nContent-Type: text/plain\r\n\r\n" + body + + request = deserialize_request(raw_data) + + assert request.method == "PUT" + assert request.get_data() == body + + def test_deserialize_invalid_request_line(self): + raw_data = b"INVALID\r\n\r\n" # Only one part, should fail + + with pytest.raises(ValueError, match="Invalid request line"): + deserialize_request(raw_data) + + def test_roundtrip_request(self): + # Test that serialize -> deserialize produces equivalent request + from io import BytesIO + + body = b"test body content" + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/api/echo", + "QUERY_STRING": "format=json", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8080", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "CONTENT_TYPE": "text/plain", + "HTTP_CONTENT_TYPE": "text/plain", + "HTTP_X_REQUEST_ID": "req-123", + } + original_request = Request(environ) + + # Serialize and deserialize + raw_data = serialize_request(original_request) + restored_request = deserialize_request(raw_data) + + # Verify key properties are preserved + assert restored_request.method == original_request.method + assert restored_request.path == original_request.path + assert restored_request.query_string == original_request.query_string + assert restored_request.get_data() == body + assert restored_request.headers.get("X-Request-Id") == "req-123" + + +class TestSerializeResponse: + def test_serialize_simple_response(self): + response = Response("Hello, World!", status=200) + + raw_data = serialize_response(response) + + assert raw_data.startswith(b"HTTP/1.1 200 OK\r\n") + assert b"\r\n\r\n" in raw_data + assert raw_data.endswith(b"Hello, World!") + + def test_serialize_response_with_headers(self): + response = Response( + '{"status": "success"}', + status=201, + headers={ + "Content-Type": "application/json", + "X-Request-Id": "req-456", + }, + ) + + raw_data = serialize_response(response) + + assert b"HTTP/1.1 201 CREATED\r\n" in raw_data + assert b"Content-Type: application/json" in raw_data + assert b"X-Request-Id: req-456" in raw_data + assert raw_data.endswith(b'{"status": "success"}') + + def test_serialize_error_response(self): + response = Response( + "Not Found", + status=404, + headers={"Content-Type": "text/plain"}, + ) + + raw_data = serialize_response(response) + + assert b"HTTP/1.1 404 NOT FOUND\r\n" in raw_data + assert b"Content-Type: text/plain" in raw_data + assert raw_data.endswith(b"Not Found") + + def test_serialize_response_without_body(self): + response = Response(status=204) # No Content + + raw_data = serialize_response(response) + + assert b"HTTP/1.1 204 NO CONTENT\r\n" in raw_data + assert raw_data.endswith(b"\r\n\r\n") # Should end with empty line + + def test_serialize_response_with_binary_body(self): + binary_data = b"\x00\x01\x02\x03\x04\x05" + response = Response( + binary_data, + status=200, + headers={"Content-Type": "application/octet-stream"}, + ) + + raw_data = serialize_response(response) + + assert b"HTTP/1.1 200 OK\r\n" in raw_data + assert b"Content-Type: application/octet-stream" in raw_data + assert raw_data.endswith(binary_data) + + +class TestDeserializeResponse: + def test_deserialize_simple_response(self): + raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!" + + response = deserialize_response(raw_data) + + assert response.status_code == 200 + assert response.get_data() == b"Hello, World!" + assert response.headers.get("Content-Type") == "text/plain" + + def test_deserialize_response_with_json(self): + body = b'{"result": "success", "data": [1, 2, 3]}' + raw_data = ( + b"HTTP/1.1 201 Created\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: " + str(len(body)).encode() + b"\r\n" + b"X-Custom-Header: test-value\r\n" + b"\r\n" + body + ) + + response = deserialize_response(raw_data) + + assert response.status_code == 201 + assert response.get_data() == body + assert response.headers.get("Content-Type") == "application/json" + assert response.headers.get("X-Custom-Header") == "test-value" + + def test_deserialize_error_response(self): + raw_data = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\nPage not found" + + response = deserialize_response(raw_data) + + assert response.status_code == 404 + assert response.get_data() == b"Page not found" + + def test_deserialize_response_without_body(self): + raw_data = b"HTTP/1.1 204 No Content\r\n\r\n" + + response = deserialize_response(raw_data) + + assert response.status_code == 204 + assert response.get_data() == b"" + + def test_deserialize_response_with_multiline_body(self): + body = b"Line 1\r\nLine 2\r\nLine 3" + raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n" + body + + response = deserialize_response(raw_data) + + assert response.status_code == 200 + assert response.get_data() == body + + def test_deserialize_response_minimal_status_line(self): + # Test with minimal status line (no status text) + raw_data = b"HTTP/1.1 200\r\n\r\nOK" + + response = deserialize_response(raw_data) + + assert response.status_code == 200 + assert response.get_data() == b"OK" + + def test_deserialize_invalid_status_line(self): + raw_data = b"INVALID\r\n\r\n" + + with pytest.raises(ValueError, match="Invalid status line"): + deserialize_response(raw_data) + + def test_roundtrip_response(self): + # Test that serialize -> deserialize produces equivalent response + original_response = Response( + '{"message": "test"}', + status=200, + headers={ + "Content-Type": "application/json", + "X-Request-Id": "abc-123", + "Cache-Control": "no-cache", + }, + ) + + # Serialize and deserialize + raw_data = serialize_response(original_response) + restored_response = deserialize_response(raw_data) + + # Verify key properties are preserved + assert restored_response.status_code == original_response.status_code + assert restored_response.get_data() == original_response.get_data() + assert restored_response.headers.get("Content-Type") == "application/json" + assert restored_response.headers.get("X-Request-Id") == "abc-123" + assert restored_response.headers.get("Cache-Control") == "no-cache" + + +class TestEdgeCases: + def test_request_with_empty_headers(self): + raw_data = b"GET / HTTP/1.1\r\n\r\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.path == "/" + + def test_response_with_empty_headers(self): + raw_data = b"HTTP/1.1 200 OK\r\n\r\nSuccess" + + response = deserialize_response(raw_data) + + assert response.status_code == 200 + assert response.get_data() == b"Success" + + def test_request_with_special_characters_in_path(self): + raw_data = b"GET /api/test%20path?key=%26value HTTP/1.1\r\n\r\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert "/api/test%20path" in request.full_path + + def test_response_with_binary_content(self): + binary_body = bytes(range(256)) # All possible byte values + raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n" + binary_body + + response = deserialize_response(raw_data) + + assert response.status_code == 200 + assert response.get_data() == binary_body + + +class TestFileUploads: + def test_serialize_request_with_text_file_upload(self): + # Test multipart/form-data request with text file + from io import BytesIO + + boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW" + text_content = "Hello, this is a test file content!\nWith multiple lines." + body = ( + f"------{boundary}\r\n" + f'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n' + f"Content-Type: text/plain\r\n" + f"\r\n" + f"{text_content}\r\n" + f"------{boundary}\r\n" + f'Content-Disposition: form-data; name="description"\r\n' + f"\r\n" + f"Test file upload\r\n" + f"------{boundary}--\r\n" + ).encode() + + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/api/upload", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert b"POST /api/upload HTTP/1.1\r\n" in raw_data + assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data + assert b"Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"" in raw_data + assert text_content.encode() in raw_data + + def test_deserialize_request_with_text_file_upload(self): + # Test deserializing multipart/form-data request with text file + boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW" + text_content = "Sample text file content\nLine 2\nLine 3" + body = ( + f"------{boundary}\r\n" + f'Content-Disposition: form-data; name="document"; filename="document.txt"\r\n' + f"Content-Type: text/plain\r\n" + f"\r\n" + f"{text_content}\r\n" + f"------{boundary}\r\n" + f'Content-Disposition: form-data; name="title"\r\n' + f"\r\n" + f"My Document\r\n" + f"------{boundary}--\r\n" + ).encode() + + raw_data = ( + b"POST /api/documents HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n" + b"Content-Length: " + str(len(body)).encode() + b"\r\n" + b"\r\n" + body + ) + + request = deserialize_request(raw_data) + + assert request.method == "POST" + assert request.path == "/api/documents" + assert "multipart/form-data" in request.content_type + # The body should contain the multipart data + request_body = request.get_data() + assert b"document.txt" in request_body + assert text_content.encode() in request_body + + def test_serialize_request_with_binary_file_upload(self): + # Test multipart/form-data request with binary file (e.g., image) + from io import BytesIO + + boundary = "----BoundaryString123" + # Simulate a small PNG file header + binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10" + + # Build multipart body + body_parts = [] + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="image"; filename="test.png"') + body_parts.append(b"Content-Type: image/png") + body_parts.append(b"") + body_parts.append(binary_content) + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="caption"') + body_parts.append(b"") + body_parts.append(b"Test image") + body_parts.append(f"------{boundary}--".encode()) + + body = b"\r\n".join(body_parts) + + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/api/images", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert b"POST /api/images HTTP/1.1\r\n" in raw_data + assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data + assert b"filename=\"test.png\"" in raw_data + assert b"Content-Type: image/png" in raw_data + assert binary_content in raw_data + + def test_deserialize_request_with_binary_file_upload(self): + # Test deserializing multipart/form-data request with binary file + boundary = "----BoundaryABC123" + # Simulate a small JPEG file header + binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00" + + body_parts = [] + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"') + body_parts.append(b"Content-Type: image/jpeg") + body_parts.append(b"") + body_parts.append(binary_content) + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="album"') + body_parts.append(b"") + body_parts.append(b"Vacation 2024") + body_parts.append(f"------{boundary}--".encode()) + + body = b"\r\n".join(body_parts) + + raw_data = ( + b"POST /api/photos HTTP/1.1\r\n" + b"Host: api.example.com\r\n" + b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n" + b"Content-Length: " + str(len(body)).encode() + b"\r\n" + b"Accept: application/json\r\n" + b"\r\n" + body + ) + + request = deserialize_request(raw_data) + + assert request.method == "POST" + assert request.path == "/api/photos" + assert "multipart/form-data" in request.content_type + assert request.headers.get("Accept") == "application/json" + + # Verify the binary content is preserved + request_body = request.get_data() + assert b"photo.jpg" in request_body + assert b"image/jpeg" in request_body + assert binary_content in request_body + assert b"Vacation 2024" in request_body + + def test_serialize_request_with_multiple_files(self): + # Test request with multiple file uploads + from io import BytesIO + + boundary = "----MultiFilesBoundary" + text_file = b"Text file contents" + binary_file = b"\x00\x01\x02\x03\x04\x05" + + body_parts = [] + # First file (text) + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="files"; filename="doc.txt"') + body_parts.append(b"Content-Type: text/plain") + body_parts.append(b"") + body_parts.append(text_file) + # Second file (binary) + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="files"; filename="data.bin"') + body_parts.append(b"Content-Type: application/octet-stream") + body_parts.append(b"") + body_parts.append(binary_file) + # Additional form field + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="folder"') + body_parts.append(b"") + body_parts.append(b"uploads/2024") + body_parts.append(f"------{boundary}--".encode()) + + body = b"\r\n".join(body_parts) + + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/api/batch-upload", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "https", + "CONTENT_LENGTH": str(len(body)), + "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + "HTTP_X_FORWARDED_PROTO": "https", + } + request = Request(environ) + + raw_data = serialize_request(request) + + assert b"POST /api/batch-upload HTTP/1.1\r\n" in raw_data + assert b"doc.txt" in raw_data + assert b"data.bin" in raw_data + assert text_file in raw_data + assert binary_file in raw_data + assert b"uploads/2024" in raw_data + + def test_roundtrip_file_upload_request(self): + # Test that file upload request survives serialize -> deserialize + from io import BytesIO + + boundary = "----RoundTripBoundary" + file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80" + + body_parts = [] + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"') + body_parts.append(b"Content-Type: text/plain; charset=utf-8") + body_parts.append(b"") + body_parts.append(file_content) + body_parts.append(f"------{boundary}".encode()) + body_parts.append(b'Content-Disposition: form-data; name="metadata"') + body_parts.append(b"") + body_parts.append(b'{"encoding": "utf-8", "size": 42}') + body_parts.append(f"------{boundary}--".encode()) + + body = b"\r\n".join(body_parts) + + environ = { + "REQUEST_METHOD": "PUT", + "PATH_INFO": "/api/files/123", + "QUERY_STRING": "version=2", + "SERVER_NAME": "storage.example.com", + "SERVER_PORT": "443", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "https", + "CONTENT_LENGTH": str(len(body)), + "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}", + "HTTP_AUTHORIZATION": "Bearer token123", + "HTTP_X_FORWARDED_PROTO": "https", + } + original_request = Request(environ) + + # Serialize and deserialize + raw_data = serialize_request(original_request) + restored_request = deserialize_request(raw_data) + + # Verify the request is preserved + assert restored_request.method == "PUT" + assert restored_request.path == "/api/files/123" + assert restored_request.query_string == b"version=2" + assert "multipart/form-data" in restored_request.content_type + assert boundary in restored_request.content_type + + # Verify file content is preserved + restored_body = restored_request.get_data() + assert b"emoji.txt" in restored_body + assert file_content in restored_body + assert b'{"encoding": "utf-8", "size": 42}' in restored_body