feat(trigger): enhance trigger subscription management and processing

- Refactor trigger provider classes to improve naming consistency and clarity
- Introduce new methods for managing trigger subscriptions, including validation and dispatching
- Update API endpoints to reflect changes in subscription handling
- Implement logging and request management for endpoint interactions
- Enhance data models to support subscription attributes and lifecycle management

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Harry 2025-09-01 12:08:48 +08:00
parent 6acc77d86d
commit 2f08306695
16 changed files with 1630 additions and 101 deletions

View File

@ -19,10 +19,6 @@ from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
logger = logging.getLogger(__name__)
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@ -87,7 +83,7 @@ class WebhookTriggerApi(Resource):
base_url = dify_config.SERVICE_API_URL
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}"
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}"
return webhook_trigger
@setup_required
@ -231,7 +227,7 @@ class AppTriggerEnableApi(Resource):
trigger.icon = url_prefix + trigger.provider_name + "/icon"
else:
trigger.icon = ""
return trigger

View File

@ -31,7 +31,7 @@ class TriggerProviderListApi(Resource):
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderSubscriptionListApi(Resource):
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -54,7 +54,7 @@ class TriggerProviderSubscriptionListApi(Resource):
raise
class TriggerProviderSubscriptionsAddApi(Resource):
class TriggerSubscriptionsAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -99,7 +99,7 @@ class TriggerProviderSubscriptionsAddApi(Resource):
raise
class TriggerProviderSubscriptionsDeleteApi(Resource):
class TriggerSubscriptionsDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -125,7 +125,7 @@ class TriggerProviderSubscriptionsDeleteApi(Resource):
raise
class TriggerProviderOAuthAuthorizeApi(Resource):
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -189,7 +189,7 @@ class TriggerProviderOAuthAuthorizeApi(Resource):
raise
class TriggerProviderOAuthCallbackApi(Resource):
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
@ -252,7 +252,7 @@ class TriggerProviderOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerProviderOAuthRefreshTokenApi(Resource):
class TriggerOAuthRefreshTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -278,7 +278,7 @@ class TriggerProviderOAuthRefreshTokenApi(Resource):
raise
class TriggerProviderOAuthClientManageApi(Resource):
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -381,25 +381,25 @@ class TriggerProviderOAuthClientManageApi(Resource):
# Trigger provider endpoints
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers")
api.add_resource(
TriggerProviderSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/list"
TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/list"
)
api.add_resource(
TriggerProviderSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/add"
TriggerSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/add"
)
api.add_resource(
TriggerProviderSubscriptionsDeleteApi,
TriggerSubscriptionsDeleteApi,
"/workspaces/current/trigger-provider/subscriptions/<path:subscription_id>/delete",
)
# OAuth
api.add_resource(
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
)
api.add_resource(TriggerProviderOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(
TriggerProviderOAuthRefreshTokenApi,
TriggerOAuthRefreshTokenApi,
"/workspaces/current/trigger-provider/subscriptions/<path:subscription_id>/oauth/refresh",
)
api.add_resource(
TriggerProviderOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
)

View File

@ -1,21 +1,45 @@
import logging
import re
from flask import jsonify, request
from werkzeug.exceptions import NotFound
from controllers.trigger import bp
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_validation_service import TriggerSubscriptionValidationService
from services.trigger_service import TriggerService
logger = logging.getLogger(__name__)
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)
@bp.route("/trigger/webhook/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_webhook(endpoint_id: str):
@bp.route(
"/trigger/endpoint/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
)
@bp.route(
"/trigger/endpoint-debug/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
)
def trigger_endpoint(endpoint_id: str):
"""
Handle webhook trigger calls.
Handle endpoint trigger calls.
"""
# endpoint_id must be UUID
if not UUID_MATCHER.match(endpoint_id):
raise NotFound("Invalid endpoint ID")
handling_chain = [
TriggerService.process_endpoint,
TriggerSubscriptionValidationService.process_validating_endpoint,
]
try:
return TriggerService.process_webhook(endpoint_id, request)
for handler in handling_chain:
response = handler(endpoint_id, request)
if response:
break
if not response:
raise NotFound("Endpoint not found")
return response
except ValueError as e:
raise NotFound(str(e))
except Exception as e:

View File

@ -1,5 +1,6 @@
from typing import Any, Literal, Optional
from flask import Response
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.provider_entities import BasicProviderConfig
@ -237,3 +238,24 @@ class RequestFetchAppInfo(BaseModel):
"""
app_id: str
class TriggerInvokeResponse(BaseModel):
event: dict[str, Any]
class PluginTriggerDispatchResponse(BaseModel):
triggers: list[str]
raw_http_response: str
class TriggerSubscriptionResponse(BaseModel):
subscription: dict[str, Any]
class TriggerValidateProviderCredentialsResponse(BaseModel):
valid: bool
message: str
error: str
class TriggerDispatchResponse:
triggers: list[str]
response: Response

View File

@ -1,14 +1,26 @@
import binascii
from typing import Any
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import PluginTriggerProviderEntity
from flask import Request
from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import (
PluginTriggerDispatchResponse,
TriggerDispatchResponse,
TriggerInvokeResponse,
TriggerSubscriptionResponse,
TriggerValidateProviderCredentialsResponse,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.http_parser import deserialize_response, serialize_request
from core.trigger.entities.entities import Subscription
class PluginTriggerManager(BasePluginClient):
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
"""
Fetch tool providers for the given tenant.
Fetch trigger providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
@ -31,7 +43,7 @@ class PluginTriggerManager(BasePluginClient):
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
# override the provider name for each trigger to plugin_id/provider_name
for trigger in provider.declaration.triggers:
trigger.identity.provider = provider.declaration.identity.name
@ -39,7 +51,7 @@ class PluginTriggerManager(BasePluginClient):
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.
Fetch trigger provider for the given tenant and plugin.
"""
def transformer(json_response: dict[str, Any]) -> dict:
@ -65,3 +77,224 @@ class PluginTriggerManager(BasePluginClient):
trigger.identity.provider = response.declaration.identity.name
return response
def invoke_trigger(
self,
tenant_id: str,
user_id: str,
provider: str,
trigger: str,
credentials: dict[str, Any],
credential_type: CredentialType,
request: Request,
parameters: dict[str, Any],
) -> TriggerInvokeResponse:
"""
Invoke a trigger with the given parameters.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/invoke",
TriggerInvokeResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"trigger": trigger,
"credentials": credentials,
"credential_type": credential_type,
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
"parameters": parameters,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return TriggerInvokeResponse(event=resp.event)
raise ValueError("No response received from plugin daemon for invoke trigger")
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
) -> TriggerValidateProviderCredentialsResponse:
"""
Validate the credentials of the trigger provider.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
TriggerValidateProviderCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
return TriggerValidateProviderCredentialsResponse(valid=False, message="No response", error="No response")
def dispatch_event(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: dict[str, Any],
request: Request,
) -> TriggerDispatchResponse:
"""
Dispatch an event to triggers.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
PluginTriggerDispatchResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"subscription": subscription,
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return TriggerDispatchResponse(
triggers=resp.triggers,
response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
)
raise ValueError("No response received from plugin daemon for dispatch event")
def subscribe(
self,
tenant_id: str,
user_id: str,
provider: str,
credentials: dict[str, Any],
endpoint: str,
parameters: dict[str, Any],
) -> TriggerSubscriptionResponse:
"""
Subscribe to a trigger.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/subscribe",
TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"credentials": credentials,
"endpoint": endpoint,
"parameters": parameters,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for subscribe")
def unsubscribe(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Subscription,
credentials: dict[str, Any],
) -> TriggerSubscriptionResponse:
"""
Unsubscribe from a trigger.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for unsubscribe")
def refresh(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Subscription,
credentials: dict[str, Any],
) -> TriggerSubscriptionResponse:
"""
Refresh a trigger subscription.
"""
trigger_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/trigger/refresh",
TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": trigger_provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": trigger_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh")

View File

@ -0,0 +1,182 @@
from io import BytesIO
from typing import Any
from flask import Request, Response
from werkzeug.datastructures import Headers
def serialize_request(request: Request) -> bytes:
"""
Convert a Request object to raw HTTP data.
Args:
request: The Request object to convert.
Returns:
The raw HTTP data as bytes.
"""
# Start with the request line
method = request.method
path = request.full_path
# Remove trailing ? if there's no query string
path = path.removesuffix("?")
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
raw_data = f"{method} {path} {protocol}\r\n".encode()
# Add headers
for header_name, header_value in request.headers.items():
raw_data += f"{header_name}: {header_value}\r\n".encode()
# Add empty line to separate headers from body
raw_data += b"\r\n"
# Add body if exists
body = request.get_data(as_text=False)
if body:
raw_data += body
return raw_data
def deserialize_request(raw_data: bytes) -> Request:
"""
Convert raw HTTP data to a Request object.
Args:
raw_data: The raw HTTP data as bytes.
Returns:
A Flask Request object.
"""
lines = raw_data.split(b"\r\n")
# Parse request line
request_line = lines[0].decode("utf-8")
parts = request_line.split(" ", 2) # Split into max 3 parts
if len(parts) < 2:
raise ValueError(f"Invalid request line: {request_line}")
method = parts[0]
path = parts[1]
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
# Parse headers
headers = Headers()
body_start = 0
for i in range(1, len(lines)):
line = lines[i]
if line == b"":
body_start = i + 1
break
if b":" in line:
header_line = line.decode("utf-8")
name, value = header_line.split(":", 1)
headers[name.strip()] = value.strip()
# Extract body
body = b""
if body_start > 0 and body_start < len(lines):
body = b"\r\n".join(lines[body_start:])
# Create environ for Request
environ = {
"REQUEST_METHOD": method,
"PATH_INFO": path.split("?")[0] if "?" in path else path,
"QUERY_STRING": path.split("?")[1] if "?" in path else "",
"SERVER_NAME": headers.get("Host", "localhost").split(":")[0],
"SERVER_PORT": headers.get("Host", "localhost:80").split(":")[1] if ":" in headers.get("Host", "") else "80",
"SERVER_PROTOCOL": protocol,
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "https" if headers.get("X-Forwarded-Proto") == "https" else "http",
"CONTENT_LENGTH": str(len(body)) if body else "0",
"CONTENT_TYPE": headers.get("Content-Type", ""),
}
# Add headers to environ
for header_name, header_value in headers.items():
env_name = f"HTTP_{header_name.upper().replace('-', '_')}"
if header_name.upper() not in ["CONTENT-TYPE", "CONTENT-LENGTH"]:
environ[env_name] = header_value
return Request(environ)
def serialize_response(response: Response) -> bytes:
"""
Convert a Response object to raw HTTP data.
Args:
response: The Response object to convert.
Returns:
The raw HTTP data as bytes.
"""
# Start with the status line
protocol = "HTTP/1.1"
status_code = response.status_code
status_text = response.status.split(" ", 1)[1] if " " in response.status else "OK"
raw_data = f"{protocol} {status_code} {status_text}\r\n".encode()
# Add headers
for header_name, header_value in response.headers.items():
raw_data += f"{header_name}: {header_value}\r\n".encode()
# Add empty line to separate headers from body
raw_data += b"\r\n"
# Add body if exists
body = response.get_data(as_text=False)
if body:
raw_data += body
return raw_data
def deserialize_response(raw_data: bytes) -> Response:
"""
Convert raw HTTP data to a Response object.
Args:
raw_data: The raw HTTP data as bytes.
Returns:
A Flask Response object.
"""
lines = raw_data.split(b"\r\n")
# Parse status line
status_line = lines[0].decode("utf-8")
parts = status_line.split(" ", 2)
if len(parts) < 2:
raise ValueError(f"Invalid status line: {status_line}")
protocol = parts[0]
status_code = int(parts[1])
status_text = parts[2] if len(parts) > 2 else "OK"
# Parse headers
headers: dict[str, Any] = {}
body_start = 0
for i in range(1, len(lines)):
line = lines[i]
if line == b"":
body_start = i + 1
break
if b":" in line:
header_line = line.decode("utf-8")
name, value = header_line.split(":", 1)
headers[name.strip()] = value.strip()
# Extract body
body = b""
if body_start > 0 and body_start < len(lines):
body = b"\r\n".join(lines[body_start:])
# Create Response object
response = Response(
response=body,
status=status_code,
headers=headers,
)
return response

View File

@ -7,6 +7,7 @@ from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import (
OAuthSchema,
Subscription,
SubscriptionSchema,
TriggerDescription,
TriggerEntity,
@ -39,5 +40,27 @@ class TriggerApiEntity(BaseModel):
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
class SubscriptionValidation(BaseModel):
id: str
name: str
tenant_id: str
user_id: str
provider_id: str
endpoint: str
parameters: dict
properties: dict
credentials: dict
credential_type: str
credential_expires_at: int
expires_at: int
def to_subscription(self) -> Subscription:
return Subscription(
expires_at=self.expires_at,
endpoint=self.endpoint,
parameters=self.parameters,
properties=self.properties,
)
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]

View File

@ -143,12 +143,19 @@ class Subscription(BaseModel):
Contains all information needed to manage the subscription lifecycle.
"""
expire_at: int = Field(
expires_at: int = Field(
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
)
metadata: dict[str, Any] = Field(
..., description="Metadata about the subscription in the external service, defined in subscription_schema"
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
parameters: dict[str, Any] | None = Field(
default=None,
description="""The parameters of the subscription, this is the creation parameters.
Only available when creating a new subscription by credentials(auto subscription), not manual subscription""",
)
properties: dict[str, Any] = Field(
..., description="Subscription data containing all properties and provider-specific information"
)

View File

@ -3,12 +3,19 @@ Trigger Provider Controller for managing trigger providers
"""
import logging
import time
from typing import Optional
from typing import Any, Optional
from flask import Request
from core.entities.provider_entities import BasicProviderConfig
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import (
TriggerDispatchResponse,
TriggerInvokeResponse,
TriggerValidateProviderCredentialsResponse,
)
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.api_entities import TriggerProviderApiEntity
from core.trigger.entities.entities import (
ProviderConfig,
@ -90,18 +97,36 @@ class PluginTriggerProviderController:
:return: List of subscription config schemas
"""
return self.entity.subscription_schema
# Return the parameters schema from the subscription schema
if self.entity.subscription_schema and self.entity.subscription_schema.parameters_schema:
return self.entity.subscription_schema.parameters_schema
return []
def validate_credentials(self, credentials: dict) -> None:
def validate_credentials(self, credentials: dict) -> TriggerValidateProviderCredentialsResponse:
"""
Validate credentials against schema
:param credentials: Credentials to validate
:raises ValueError: If credentials are invalid
:return: Validation response
"""
# First validate against schema
for config in self.entity.credentials_schema:
if config.required and config.name not in credentials:
raise ValueError(f"Missing required credential field: {config.name}")
return TriggerValidateProviderCredentialsResponse(
valid=False,
message=f"Missing required credential field: {config.name}",
error=f"Missing required credential field: {config.name}",
)
# Then validate with the plugin daemon
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
return manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id="system", # System validation
provider=str(provider_id),
credentials=credentials,
)
def get_supported_credential_types(self) -> list[CredentialType]:
"""
@ -143,58 +168,127 @@ class PluginTriggerProviderController:
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
@property
def need_credentials(self) -> bool:
"""Check if this provider needs credentials"""
return len(self.get_supported_credential_types()) > 0
def dispatch(self,user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
"""
Dispatch a trigger through plugin runtime
def execute_trigger(self, trigger_name: str, parameters: dict, credentials: dict) -> dict:
:param user_id: User ID
:param request: Flask request object
:param subscription: Subscription
:return: Dispatch response with triggers and raw HTTP response
"""
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
response = manager.dispatch_event(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
subscription=subscription.model_dump(),
request=request,
)
return response
def invoke_trigger(
self,
user_id: str,
trigger_name: str,
parameters: dict,
credentials: dict,
credential_type: CredentialType,
request: Request,
) -> TriggerInvokeResponse:
"""
Execute a trigger through plugin runtime
:param user_id: User ID
:param trigger_name: Trigger name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:return: Execution result
:param credential_type: Credential type
:param request: Request
:return: Trigger execution result
"""
logger.info("Executing trigger %s for plugin %s", trigger_name, self.plugin_id)
return {
"success": True,
"trigger": trigger_name,
"plugin": self.plugin_id,
"result": "Trigger executed successfully",
}
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
def subscribe_trigger(self, trigger_name: str, subscription_params: dict, credentials: dict) -> Subscription:
return manager.invoke_trigger(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
trigger=trigger_name,
credentials=credentials,
credential_type=credential_type,
request=request,
parameters=parameters,
)
def subscribe_trigger(self, user_id: str, endpoint: str, parameters: dict, credentials: dict) -> Subscription:
"""
Subscribe to a trigger through plugin runtime
:param trigger_name: Trigger name
:param user_id: User ID
:param endpoint: Subscription endpoint
:param subscription_params: Subscription parameters
:param credentials: Provider credentials
:return: Subscription result
"""
logger.info("Subscribing to trigger %s for plugin %s", trigger_name, self.plugin_id)
return Subscription(
expire_at=int(time.time()) + 86400, # 24 hours from now
metadata={
"subscription_id": f"{self.plugin_id}_{trigger_name}_{time.time()}",
"webhook_url": f"/triggers/webhook/{self.plugin_id}/{trigger_name}",
**subscription_params,
},
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
response = manager.subscribe(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
credentials=credentials,
endpoint=endpoint,
parameters=parameters,
)
def unsubscribe_trigger(self, trigger_name: str, subscription_metadata: dict, credentials: dict) -> Unsubscription:
return Subscription.model_validate(response.subscription)
def unsubscribe_trigger(self, user_id: str, subscription: Subscription, credentials: dict) -> Unsubscription:
"""
Unsubscribe from a trigger through plugin runtime
:param trigger_name: Trigger name
:param subscription_metadata: Subscription metadata
:param user_id: User ID
:param subscription: Subscription metadata
:param credentials: Provider credentials
:return: Unsubscription result
"""
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
response = manager.unsubscribe(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
subscription=subscription,
credentials=credentials,
)
return Unsubscription.model_validate(response.subscription)
def refresh_trigger(self, subscription: Subscription, credentials: dict) -> Subscription:
"""
Refresh a trigger subscription through plugin runtime
:param subscription: Subscription metadata
:param credentials: Provider credentials
:return: Refreshed subscription result
"""
manager = PluginTriggerManager()
provider_id = self.get_provider_id()
response = manager.refresh(
tenant_id=self.tenant_id,
user_id="system", # System refresh
provider=str(provider_id),
subscription=subscription,
credentials=credentials,
)
return Subscription.model_validate(response.subscription)
__all__ = ["PluginTriggerProviderController"]

View File

@ -5,7 +5,10 @@ Trigger Manager for loading and managing trigger providers and triggers
import logging
from typing import Optional
from flask import Request
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.request import TriggerInvokeResponse
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.entities import (
ProviderConfig,
@ -14,6 +17,7 @@ from core.trigger.entities.entities import (
Unsubscription,
)
from core.trigger.provider import PluginTriggerProviderController
from core.plugin.entities.plugin_daemon import CredentialType
logger = logging.getLogger(__name__)
@ -123,75 +127,90 @@ class TriggerManager:
:return: Tuple of (is_valid, error_message)
"""
try:
cls.get_trigger_provider(tenant_id, provider_id).validate_credentials(credentials)
return True, ""
provider = cls.get_trigger_provider(tenant_id, provider_id)
validation_result = provider.validate_credentials(credentials)
return validation_result.valid, validation_result.message if not validation_result.valid else ""
except Exception as e:
return False, str(e)
@classmethod
def execute_trigger(
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str, parameters: dict, credentials: dict
) -> dict:
def invoke_trigger(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
trigger_name: str,
parameters: dict,
credentials: dict,
credential_type: CredentialType,
request: Request,
) -> TriggerInvokeResponse:
"""
Execute a trigger
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:param request: Request
:return: Trigger execution result
"""
trigger = cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
provider = cls.get_trigger_provider(tenant_id, provider_id)
trigger = provider.get_trigger(trigger_name)
if not trigger:
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
return cls.get_trigger_provider(tenant_id, provider_id).execute_trigger(trigger_name, parameters, credentials)
return provider.invoke_trigger(user_id, trigger_name, parameters, credentials, credential_type, request)
@classmethod
def subscribe_trigger(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
trigger_name: str,
subscription_params: dict,
endpoint: str,
parameters: dict,
credentials: dict,
) -> Subscription:
"""
Subscribe to a trigger (e.g., register webhook)
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param subscription_params: Subscription parameters
:param endpoint: Subscription endpoint
:param parameters: Subscription parameters
:param credentials: Provider credentials
:return: Subscription result
"""
return cls.get_trigger_provider(tenant_id, provider_id).subscribe_trigger(
trigger_name, subscription_params, credentials
provider = cls.get_trigger_provider(tenant_id, provider_id)
return provider.subscribe_trigger(
user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
)
@classmethod
def unsubscribe_trigger(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
trigger_name: str,
subscription_metadata: dict,
subscription: Subscription,
credentials: dict,
) -> Unsubscription:
"""
Unsubscribe from a trigger
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param subscription_metadata: Subscription metadata from subscribe operation
:param subscription: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:return: Unsubscription result
"""
return cls.get_trigger_provider(tenant_id, provider_id).unsubscribe_trigger(
trigger_name, subscription_metadata, credentials
)
provider = cls.get_trigger_provider(tenant_id, provider_id)
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
@classmethod
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]:
@ -204,6 +223,27 @@ class TriggerManager:
"""
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
@classmethod
def refresh_trigger(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
trigger_name: str,
subscription: Subscription,
credentials: dict,
) -> Subscription:
"""
Refresh a trigger subscription
:param tenant_id: Tenant ID
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param subscription: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:return: Refreshed subscription result
"""
return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(trigger_name, subscription, credentials)
# Export
__all__ = ["TriggerManager"]

View File

@ -23,4 +23,4 @@ webhook_trigger_fields = {
"node_id": fields.String,
"triggered_by": fields.String,
"created_at": fields.DateTime(dt_format="iso8601"),
}
}

View File

@ -8,7 +8,8 @@ from sqlalchemy import DateTime, Index, Integer, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
from core.trigger.entities.api_entities import SubscriptionValidation, TriggerProviderSubscriptionApiEntity
from core.trigger.entities.entities import Subscription
from models.base import Base
from models.types import StringUUID
@ -24,6 +25,7 @@ class TriggerSubscription(Base):
sa.PrimaryKeyConstraint("id", name="trigger_subscription_pkey"),
Index("idx_trigger_subscriptions_tenant_provider", "tenant_id", "provider_id"),
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_subscription"),
UniqueConstraint("endpoint", name="unique_trigger_subscription_endpoint"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
@ -33,8 +35,9 @@ class TriggerSubscription(Base):
provider_id: Mapped[str] = mapped_column(
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
)
endpoint: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
parameters: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
configuration: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription configuration JSON")
properties: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription credentials JSON")
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
@ -57,6 +60,14 @@ class TriggerSubscription(Base):
# Check if token expires in next 3 minutes
return (self.credential_expires_at - 180) < int(time.time())
def to_entity(self) -> Subscription:
return Subscription(
expires_at=self.expires_at,
endpoint=self.endpoint,
parameters=self.parameters,
properties=self.properties,
)
def to_api_entity(self) -> TriggerProviderSubscriptionApiEntity:
return TriggerProviderSubscriptionApiEntity(
id=self.id,
@ -66,7 +77,6 @@ class TriggerSubscription(Base):
credentials=self.credentials,
)
# system level trigger oauth client params
class TriggerOAuthSystemClient(Base):
__tablename__ = "trigger_oauth_system_clients"

View File

@ -4,6 +4,7 @@ import re
from collections.abc import Mapping
from typing import Any, Optional
from flask import Request, Response
from sqlalchemy import desc
from sqlalchemy.orm import Session
@ -15,7 +16,11 @@ from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity
from core.trigger.entities.api_entities import (
SubscriptionValidation,
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
)
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_subscription,
@ -32,7 +37,7 @@ logger = logging.getLogger(__name__)
class TriggerProviderService:
"""Service for managing trigger providers and credentials"""
__MAX_TRIGGER_PROVIDER_COUNT__ = 100
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
@classmethod
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
@ -553,3 +558,24 @@ class TriggerProviderService:
except Exception as e:
logger.warning("Error generating provider name")
return f"{credential_type.get_name()} 1"
@classmethod
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
"""
Get a trigger subscription by the endpoint ID.
"""
with Session(db.engine, autoflush=False) as session:
subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first()
return subscription
@classmethod
def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None:
"""
Get a trigger subscription by the endpoint ID.
"""
cache_key = f"trigger:subscription:validation:endpoint:{endpoint_id}"
subscription_cache = redis_client.get(cache_key)
if subscription_cache:
return SubscriptionValidation.model_validate(json.loads(subscription_cache))
return None

View File

@ -0,0 +1,48 @@
import logging
from flask import Request, Response
from core.plugin.entities.plugin import TriggerProviderID
from core.trigger.trigger_manager import TriggerManager
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerSubscriptionValidationService:
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
@classmethod
def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
"""
Append the validation request log to Redis.
"""
@classmethod
def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
"""
Process a temporary endpoint request.
:param endpoint_id: The endpoint identifier
:param request: The Flask request object
:return: The Flask response object
"""
# check if validation endpoint exists
subscription_validation = TriggerProviderService.get_subscription_validation(endpoint_id)
if not subscription_validation:
return None
# response to validation endpoint
controller = TriggerManager.get_trigger_provider(
subscription_validation.tenant_id, TriggerProviderID(subscription_validation.provider_id)
)
response = controller.dispatch(
user_id=subscription_validation.user_id,
request=request,
subscription=subscription_validation.to_subscription(),
)
# append the request log
cls.append_validation_request_log(endpoint_id, request, response.response)
return response.response

View File

@ -1,23 +1,192 @@
import json
import logging
import time
import uuid
from typing import Any, Optional
from flask import Request, Response
from core.plugin.entities.plugin import TriggerProviderID
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_redis import redis_client
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerService:
__MAX_REQUEST_LOG_COUNT__ = 10
__TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
__ENDPOINT_REQUEST_CACHE_COUNT__ = 10
__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
# Lua script for atomic write with time & count based cleanup
__LUA_SCRIPT__ = """
-- KEYS[1] = zset key
-- ARGV[1] = max_count (maximum number of entries to keep)
-- ARGV[2] = min_ts_ms (minimum timestamp to keep = now_ms - ttl_ms)
-- ARGV[3] = now_ms (current timestamp in milliseconds)
-- ARGV[4] = member (log entry JSON)
local key = KEYS[1]
local maxCount = tonumber(ARGV[1])
local minTs = tonumber(ARGV[2])
local nowMs = tonumber(ARGV[3])
local member = ARGV[4]
-- 1) Add new entry with timestamp as score
redis.call('ZADD', key, nowMs, member)
-- 2) Remove entries older than minTs (time-based cleanup)
redis.call('ZREMRANGEBYSCORE', key, '-inf', minTs)
-- 3) Remove oldest entries if count exceeds maxCount (count-based cleanup)
local n = redis.call('ZCARD', key)
if n > maxCount then
redis.call('ZREMRANGEBYRANK', key, 0, n - maxCount - 1) -- 0 is oldest
end
return n
"""
@classmethod
def process_webhook(cls, webhook_id: str, request: Request) -> Response:
"""Extract and process data from incoming webhook request."""
# TODO redis slidingwindow log, save the recent request log in redis, rollover the log when the window is full
def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
"""Extract and process data from incoming endpoint request."""
subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
if not subscription:
return None
provider_id = TriggerProviderID(subscription.provider_id)
controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id)
if not controller:
return None
dispatch_response = controller.dispatch(
user_id=subscription.user_id, request=request, subscription=subscription.to_entity()
)
# TODO find the trigger subscription
# TODO invoke triggers
# dispatch_response.triggers
# TODO fetch the trigger controller
return dispatch_response.response
# TODO dispatch by the trigger controller
@classmethod
def log_endpoint_request(cls, endpoint_id: str, request: Request) -> int:
"""
Log the endpoint request to Redis using ZSET for rolling log with time & count based retention.
# TODO using the dispatch result(events) to invoke the trigger events
raise NotImplementedError("Not implemented")
Args:
endpoint_id: The endpoint identifier
request: The Flask request object
Returns:
The current number of logged requests for this endpoint
"""
try:
# Prepare timestamp
now_ms = int(time.time() * 1000)
min_ts = now_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
# Extract request data
request_data = {
"id": str(uuid.uuid4()),
"timestamp": now_ms,
"method": request.method,
"path": request.path,
"headers": dict(request.headers),
"query_params": request.args.to_dict(flat=False) if request.args else {},
"body": None,
"remote_addr": request.remote_addr,
}
# Try to get request body if it exists
if request.is_json:
try:
request_data["body"] = request.get_json(force=True)
except Exception:
request_data["body"] = request.get_data(as_text=True)
elif request.data:
request_data["body"] = request.get_data(as_text=True)
# Serialize to JSON
member = json.dumps(request_data, separators=(",", ":"))
# Execute Lua script atomically
key = f"trigger:endpoint_requests:{endpoint_id}"
count = redis_client.eval(
cls.__LUA_SCRIPT__,
1, # number of keys
key, # KEYS[1]
str(cls.__ENDPOINT_REQUEST_CACHE_COUNT__), # ARGV[1] - max count
str(min_ts), # ARGV[2] - minimum timestamp
str(now_ms), # ARGV[3] - current timestamp
member, # ARGV[4] - log entry
)
logger.debug("Logged request for endpoint %s, current count: %s", endpoint_id, count)
return count
except Exception as e:
logger.exception("Failed to log endpoint request for %s", endpoint_id, exc_info=e)
# Don't fail the main request processing if logging fails
return 0
@classmethod
def get_recent_endpoint_requests(
cls, endpoint_id: str, limit: int = 100, start_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None
) -> list[dict[str, Any]]:
"""
Retrieve recent logged requests for an endpoint.
Args:
endpoint_id: The endpoint identifier
limit: Maximum number of entries to return
start_time_ms: Start timestamp in milliseconds (optional)
end_time_ms: End timestamp in milliseconds (optional, defaults to now)
Returns:
List of request log entries, newest first
"""
try:
key = f"trigger:endpoint_requests:{endpoint_id}"
# Set time bounds
if end_time_ms is None:
end_time_ms = int(time.time() * 1000)
if start_time_ms is None:
start_time_ms = end_time_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
# Get entries in reverse order (newest first)
entries = redis_client.zrevrangebyscore(key, max=end_time_ms, min=start_time_ms, start=0, num=limit)
# Parse JSON entries
requests = []
for entry in entries:
try:
requests.append(json.loads(entry))
except json.JSONDecodeError:
logger.warning("Failed to parse log entry: %s", entry)
return requests
except Exception as e:
logger.exception("Failed to retrieve endpoint requests for %s", endpoint_id, exc_info=e)
return []
@classmethod
def clear_endpoint_requests(cls, endpoint_id: str) -> bool:
"""
Clear all logged requests for an endpoint.
Args:
endpoint_id: The endpoint identifier
Returns:
True if successful, False otherwise
"""
try:
key = f"trigger:endpoint_requests:{endpoint_id}"
redis_client.delete(key)
logger.info("Cleared request logs for endpoint %s", endpoint_id)
return True
except Exception as e:
logger.exception("Failed to clear endpoint requests for %s", endpoint_id, exc_info=e)
return False

View File

@ -0,0 +1,655 @@
import pytest
from flask import Request, Response
from core.plugin.utils.http_parser import (
deserialize_request,
deserialize_response,
serialize_request,
serialize_response,
)
class TestSerializeRequest:
def test_serialize_simple_get_request(self):
# Create a simple GET request
environ = {
"REQUEST_METHOD": "GET",
"PATH_INFO": "/api/test",
"QUERY_STRING": "",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": None,
"wsgi.url_scheme": "http",
}
request = Request(environ)
raw_data = serialize_request(request)
assert raw_data.startswith(b"GET /api/test HTTP/1.1\r\n")
assert b"\r\n\r\n" in raw_data # Empty line between headers and body
def test_serialize_request_with_query_params(self):
# Create a GET request with query parameters
environ = {
"REQUEST_METHOD": "GET",
"PATH_INFO": "/api/search",
"QUERY_STRING": "q=test&limit=10",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": None,
"wsgi.url_scheme": "http",
}
request = Request(environ)
raw_data = serialize_request(request)
assert raw_data.startswith(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\n")
def test_serialize_post_request_with_body(self):
# Create a POST request with body
from io import BytesIO
body = b'{"name": "test", "value": 123}'
environ = {
"REQUEST_METHOD": "POST",
"PATH_INFO": "/api/data",
"QUERY_STRING": "",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "http",
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": "application/json",
"HTTP_CONTENT_TYPE": "application/json",
}
request = Request(environ)
raw_data = serialize_request(request)
assert b"POST /api/data HTTP/1.1\r\n" in raw_data
assert b"Content-Type: application/json" in raw_data
assert raw_data.endswith(body)
def test_serialize_request_with_custom_headers(self):
# Create a request with custom headers
environ = {
"REQUEST_METHOD": "GET",
"PATH_INFO": "/api/test",
"QUERY_STRING": "",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": None,
"wsgi.url_scheme": "http",
"HTTP_AUTHORIZATION": "Bearer token123",
"HTTP_X_CUSTOM_HEADER": "custom-value",
}
request = Request(environ)
raw_data = serialize_request(request)
assert b"Authorization: Bearer token123" in raw_data
assert b"X-Custom-Header: custom-value" in raw_data
class TestDeserializeRequest:
def test_deserialize_simple_get_request(self):
raw_data = b"GET /api/test HTTP/1.1\r\nHost: localhost:8000\r\n\r\n"
request = deserialize_request(raw_data)
assert request.method == "GET"
assert request.path == "/api/test"
assert request.headers.get("Host") == "localhost:8000"
def test_deserialize_request_with_query_params(self):
raw_data = b"GET /api/search?q=test&limit=10 HTTP/1.1\r\nHost: example.com\r\n\r\n"
request = deserialize_request(raw_data)
assert request.method == "GET"
assert request.path == "/api/search"
assert request.query_string == b"q=test&limit=10"
assert request.args.get("q") == "test"
assert request.args.get("limit") == "10"
def test_deserialize_post_request_with_body(self):
body = b'{"name": "test", "value": 123}'
raw_data = (
b"POST /api/data HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
b"\r\n" + body
)
request = deserialize_request(raw_data)
assert request.method == "POST"
assert request.path == "/api/data"
assert request.content_type == "application/json"
assert request.get_data() == body
def test_deserialize_request_with_custom_headers(self):
raw_data = (
b"GET /api/protected HTTP/1.1\r\n"
b"Host: api.example.com\r\n"
b"Authorization: Bearer token123\r\n"
b"X-Custom-Header: custom-value\r\n"
b"User-Agent: TestClient/1.0\r\n"
b"\r\n"
)
request = deserialize_request(raw_data)
assert request.method == "GET"
assert request.headers.get("Authorization") == "Bearer token123"
assert request.headers.get("X-Custom-Header") == "custom-value"
assert request.headers.get("User-Agent") == "TestClient/1.0"
def test_deserialize_request_with_multiline_body(self):
body = b"line1\r\nline2\r\nline3"
raw_data = b"PUT /api/text HTTP/1.1\r\nHost: localhost\r\nContent-Type: text/plain\r\n\r\n" + body
request = deserialize_request(raw_data)
assert request.method == "PUT"
assert request.get_data() == body
def test_deserialize_invalid_request_line(self):
raw_data = b"INVALID\r\n\r\n" # Only one part, should fail
with pytest.raises(ValueError, match="Invalid request line"):
deserialize_request(raw_data)
def test_roundtrip_request(self):
# Test that serialize -> deserialize produces equivalent request
from io import BytesIO
body = b"test body content"
environ = {
"REQUEST_METHOD": "POST",
"PATH_INFO": "/api/echo",
"QUERY_STRING": "format=json",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8080",
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "http",
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": "text/plain",
"HTTP_CONTENT_TYPE": "text/plain",
"HTTP_X_REQUEST_ID": "req-123",
}
original_request = Request(environ)
# Serialize and deserialize
raw_data = serialize_request(original_request)
restored_request = deserialize_request(raw_data)
# Verify key properties are preserved
assert restored_request.method == original_request.method
assert restored_request.path == original_request.path
assert restored_request.query_string == original_request.query_string
assert restored_request.get_data() == body
assert restored_request.headers.get("X-Request-Id") == "req-123"
class TestSerializeResponse:
def test_serialize_simple_response(self):
response = Response("Hello, World!", status=200)
raw_data = serialize_response(response)
assert raw_data.startswith(b"HTTP/1.1 200 OK\r\n")
assert b"\r\n\r\n" in raw_data
assert raw_data.endswith(b"Hello, World!")
def test_serialize_response_with_headers(self):
response = Response(
'{"status": "success"}',
status=201,
headers={
"Content-Type": "application/json",
"X-Request-Id": "req-456",
},
)
raw_data = serialize_response(response)
assert b"HTTP/1.1 201 CREATED\r\n" in raw_data
assert b"Content-Type: application/json" in raw_data
assert b"X-Request-Id: req-456" in raw_data
assert raw_data.endswith(b'{"status": "success"}')
def test_serialize_error_response(self):
response = Response(
"Not Found",
status=404,
headers={"Content-Type": "text/plain"},
)
raw_data = serialize_response(response)
assert b"HTTP/1.1 404 NOT FOUND\r\n" in raw_data
assert b"Content-Type: text/plain" in raw_data
assert raw_data.endswith(b"Not Found")
def test_serialize_response_without_body(self):
response = Response(status=204) # No Content
raw_data = serialize_response(response)
assert b"HTTP/1.1 204 NO CONTENT\r\n" in raw_data
assert raw_data.endswith(b"\r\n\r\n") # Should end with empty line
def test_serialize_response_with_binary_body(self):
binary_data = b"\x00\x01\x02\x03\x04\x05"
response = Response(
binary_data,
status=200,
headers={"Content-Type": "application/octet-stream"},
)
raw_data = serialize_response(response)
assert b"HTTP/1.1 200 OK\r\n" in raw_data
assert b"Content-Type: application/octet-stream" in raw_data
assert raw_data.endswith(binary_data)
class TestDeserializeResponse:
def test_deserialize_simple_response(self):
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!"
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == b"Hello, World!"
assert response.headers.get("Content-Type") == "text/plain"
def test_deserialize_response_with_json(self):
body = b'{"result": "success", "data": [1, 2, 3]}'
raw_data = (
b"HTTP/1.1 201 Created\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
b"X-Custom-Header: test-value\r\n"
b"\r\n" + body
)
response = deserialize_response(raw_data)
assert response.status_code == 201
assert response.get_data() == body
assert response.headers.get("Content-Type") == "application/json"
assert response.headers.get("X-Custom-Header") == "test-value"
def test_deserialize_error_response(self):
raw_data = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n<html><body>Page not found</body></html>"
response = deserialize_response(raw_data)
assert response.status_code == 404
assert response.get_data() == b"<html><body>Page not found</body></html>"
def test_deserialize_response_without_body(self):
raw_data = b"HTTP/1.1 204 No Content\r\n\r\n"
response = deserialize_response(raw_data)
assert response.status_code == 204
assert response.get_data() == b""
def test_deserialize_response_with_multiline_body(self):
body = b"Line 1\r\nLine 2\r\nLine 3"
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n" + body
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == body
def test_deserialize_response_minimal_status_line(self):
# Test with minimal status line (no status text)
raw_data = b"HTTP/1.1 200\r\n\r\nOK"
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == b"OK"
def test_deserialize_invalid_status_line(self):
raw_data = b"INVALID\r\n\r\n"
with pytest.raises(ValueError, match="Invalid status line"):
deserialize_response(raw_data)
def test_roundtrip_response(self):
# Test that serialize -> deserialize produces equivalent response
original_response = Response(
'{"message": "test"}',
status=200,
headers={
"Content-Type": "application/json",
"X-Request-Id": "abc-123",
"Cache-Control": "no-cache",
},
)
# Serialize and deserialize
raw_data = serialize_response(original_response)
restored_response = deserialize_response(raw_data)
# Verify key properties are preserved
assert restored_response.status_code == original_response.status_code
assert restored_response.get_data() == original_response.get_data()
assert restored_response.headers.get("Content-Type") == "application/json"
assert restored_response.headers.get("X-Request-Id") == "abc-123"
assert restored_response.headers.get("Cache-Control") == "no-cache"
class TestEdgeCases:
def test_request_with_empty_headers(self):
raw_data = b"GET / HTTP/1.1\r\n\r\n"
request = deserialize_request(raw_data)
assert request.method == "GET"
assert request.path == "/"
def test_response_with_empty_headers(self):
raw_data = b"HTTP/1.1 200 OK\r\n\r\nSuccess"
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == b"Success"
def test_request_with_special_characters_in_path(self):
raw_data = b"GET /api/test%20path?key=%26value HTTP/1.1\r\n\r\n"
request = deserialize_request(raw_data)
assert request.method == "GET"
assert "/api/test%20path" in request.full_path
def test_response_with_binary_content(self):
binary_body = bytes(range(256)) # All possible byte values
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n" + binary_body
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == binary_body
class TestFileUploads:
def test_serialize_request_with_text_file_upload(self):
# Test multipart/form-data request with text file
from io import BytesIO
boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
text_content = "Hello, this is a test file content!\nWith multiple lines."
body = (
f"------{boundary}\r\n"
f'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n'
f"Content-Type: text/plain\r\n"
f"\r\n"
f"{text_content}\r\n"
f"------{boundary}\r\n"
f'Content-Disposition: form-data; name="description"\r\n'
f"\r\n"
f"Test file upload\r\n"
f"------{boundary}--\r\n"
).encode()
environ = {
"REQUEST_METHOD": "POST",
"PATH_INFO": "/api/upload",
"QUERY_STRING": "",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "http",
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
}
request = Request(environ)
raw_data = serialize_request(request)
assert b"POST /api/upload HTTP/1.1\r\n" in raw_data
assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
assert b"Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"" in raw_data
assert text_content.encode() in raw_data
def test_deserialize_request_with_text_file_upload(self):
# Test deserializing multipart/form-data request with text file
boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
text_content = "Sample text file content\nLine 2\nLine 3"
body = (
f"------{boundary}\r\n"
f'Content-Disposition: form-data; name="document"; filename="document.txt"\r\n'
f"Content-Type: text/plain\r\n"
f"\r\n"
f"{text_content}\r\n"
f"------{boundary}\r\n"
f'Content-Disposition: form-data; name="title"\r\n'
f"\r\n"
f"My Document\r\n"
f"------{boundary}--\r\n"
).encode()
raw_data = (
b"POST /api/documents HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
b"\r\n" + body
)
request = deserialize_request(raw_data)
assert request.method == "POST"
assert request.path == "/api/documents"
assert "multipart/form-data" in request.content_type
# The body should contain the multipart data
request_body = request.get_data()
assert b"document.txt" in request_body
assert text_content.encode() in request_body
def test_serialize_request_with_binary_file_upload(self):
# Test multipart/form-data request with binary file (e.g., image)
from io import BytesIO
boundary = "----BoundaryString123"
# Simulate a small PNG file header
binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10"
# Build multipart body
body_parts = []
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="image"; filename="test.png"')
body_parts.append(b"Content-Type: image/png")
body_parts.append(b"")
body_parts.append(binary_content)
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="caption"')
body_parts.append(b"")
body_parts.append(b"Test image")
body_parts.append(f"------{boundary}--".encode())
body = b"\r\n".join(body_parts)
environ = {
"REQUEST_METHOD": "POST",
"PATH_INFO": "/api/images",
"QUERY_STRING": "",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "http",
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
}
request = Request(environ)
raw_data = serialize_request(request)
assert b"POST /api/images HTTP/1.1\r\n" in raw_data
assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
assert b"filename=\"test.png\"" in raw_data
assert b"Content-Type: image/png" in raw_data
assert binary_content in raw_data
def test_deserialize_request_with_binary_file_upload(self):
# Test deserializing multipart/form-data request with binary file
boundary = "----BoundaryABC123"
# Simulate a small JPEG file header
binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"
body_parts = []
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"')
body_parts.append(b"Content-Type: image/jpeg")
body_parts.append(b"")
body_parts.append(binary_content)
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="album"')
body_parts.append(b"")
body_parts.append(b"Vacation 2024")
body_parts.append(f"------{boundary}--".encode())
body = b"\r\n".join(body_parts)
raw_data = (
b"POST /api/photos HTTP/1.1\r\n"
b"Host: api.example.com\r\n"
b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
b"Accept: application/json\r\n"
b"\r\n" + body
)
request = deserialize_request(raw_data)
assert request.method == "POST"
assert request.path == "/api/photos"
assert "multipart/form-data" in request.content_type
assert request.headers.get("Accept") == "application/json"
# Verify the binary content is preserved
request_body = request.get_data()
assert b"photo.jpg" in request_body
assert b"image/jpeg" in request_body
assert binary_content in request_body
assert b"Vacation 2024" in request_body
def test_serialize_request_with_multiple_files(self):
# Test request with multiple file uploads
from io import BytesIO
boundary = "----MultiFilesBoundary"
text_file = b"Text file contents"
binary_file = b"\x00\x01\x02\x03\x04\x05"
body_parts = []
# First file (text)
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="files"; filename="doc.txt"')
body_parts.append(b"Content-Type: text/plain")
body_parts.append(b"")
body_parts.append(text_file)
# Second file (binary)
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="files"; filename="data.bin"')
body_parts.append(b"Content-Type: application/octet-stream")
body_parts.append(b"")
body_parts.append(binary_file)
# Additional form field
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="folder"')
body_parts.append(b"")
body_parts.append(b"uploads/2024")
body_parts.append(f"------{boundary}--".encode())
body = b"\r\n".join(body_parts)
environ = {
"REQUEST_METHOD": "POST",
"PATH_INFO": "/api/batch-upload",
"QUERY_STRING": "",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8000",
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "https",
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
"HTTP_X_FORWARDED_PROTO": "https",
}
request = Request(environ)
raw_data = serialize_request(request)
assert b"POST /api/batch-upload HTTP/1.1\r\n" in raw_data
assert b"doc.txt" in raw_data
assert b"data.bin" in raw_data
assert text_file in raw_data
assert binary_file in raw_data
assert b"uploads/2024" in raw_data
def test_roundtrip_file_upload_request(self):
# Test that file upload request survives serialize -> deserialize
from io import BytesIO
boundary = "----RoundTripBoundary"
file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80"
body_parts = []
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"')
body_parts.append(b"Content-Type: text/plain; charset=utf-8")
body_parts.append(b"")
body_parts.append(file_content)
body_parts.append(f"------{boundary}".encode())
body_parts.append(b'Content-Disposition: form-data; name="metadata"')
body_parts.append(b"")
body_parts.append(b'{"encoding": "utf-8", "size": 42}')
body_parts.append(f"------{boundary}--".encode())
body = b"\r\n".join(body_parts)
environ = {
"REQUEST_METHOD": "PUT",
"PATH_INFO": "/api/files/123",
"QUERY_STRING": "version=2",
"SERVER_NAME": "storage.example.com",
"SERVER_PORT": "443",
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "https",
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
"HTTP_AUTHORIZATION": "Bearer token123",
"HTTP_X_FORWARDED_PROTO": "https",
}
original_request = Request(environ)
# Serialize and deserialize
raw_data = serialize_request(original_request)
restored_request = deserialize_request(raw_data)
# Verify the request is preserved
assert restored_request.method == "PUT"
assert restored_request.path == "/api/files/123"
assert restored_request.query_string == b"version=2"
assert "multipart/form-data" in restored_request.content_type
assert boundary in restored_request.content_type
# Verify file content is preserved
restored_body = restored_request.get_data()
assert b"emoji.txt" in restored_body
assert file_content in restored_body
assert b'{"encoding": "utf-8", "size": 42}' in restored_body